From d7a541ea2d67543660bb1c14255afb6c1b999b08 Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Wed, 17 Apr 2024 18:47:30 +0530 Subject: [PATCH 001/287] Fix check-fail in GatherV2 The op GatherV2 leads to a check failure specifically when axis=kint64max; any other out-of-range axis value raises a valid exception. --- tensorflow/core/kernels/gather_op.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 1ff8145688d35e..098cc6866a1d13 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -77,6 +77,9 @@ class GatherOp : public OpKernel { errors::InvalidArgument("axis must be int32 or int64.")); } } + // Special case to avoid a check failure when axis = kint64max. + OP_REQUIRES(c, axis < kint64max, + absl::InvalidArgumentError("axis must be less than kint64max")); int64_t min_params_dim = axis < 0 ? -axis : axis + 1; OP_REQUIRES( From 24a9d7b038fa4e87c8d5960ebadd204356243ece Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 May 2024 14:13:20 +0000 Subject: [PATCH 002/287] Merged commit includes the following changes: 637889039 by A. Unique TensorFlower: Remove experimental_adaptive_avx_optimization flag from XNNPACK delegate options It's always on now. -- 637886275 by A. Unique TensorFlower: [XLA:GPU][IndexAnalysis] Use a flag for IsKnownEmpty instead of recomputing every time. Right now, we would try to simplify or compose with indexing maps that have a known empty domain. That's incorrect, but checking whether the domain is empty every time is expensive, so the result can be cached. -- 637876088 by A. Unique TensorFlower: Internal config change -- 637864812 by A. Unique TensorFlower: PR #13088: [ROCm] Fix reduce_atomic_min.hlo.test Imported from GitHub PR https://github.com/openxla/xla/pull/13088 Copybara import of the project: -- b241e076198c03fffd8c7e3a6568070ef0223653 by mmakevic : Fix reduce_atomic_min.hlo.test -- f894f1954513019f0ca6890a27e09e0fee9d462e by mmakevic : Remove extra space Merging this change closes #13088 -- 637860531 by A. Unique TensorFlower: Remove xla_gpu_normalize_layouts flag. By now, this is really not experimental anymore. -- 637857834 by A. Unique TensorFlower: Add heuristic for when to treat Gather ops as coalesced. -- 637820064 by A. Unique TensorFlower: compat: Update forward compatibility horizon to 2024-05-28 -- 637820063 by A. Unique TensorFlower: Update GraphDef version to 1876. -- 637756070 by A. Unique TensorFlower: Automated rollback of changelist 636206934. 637674999 by A. Unique TensorFlower: [xla:cpu] Add initial support for Thunk-based execution to CpuCompiler and CpuExecutable Add support for compiling XLA:CPU HloModule to a ThunkSequence instead of an LLVM module and a jit-compiled function. -- 637666734 by A. Unique TensorFlower: Don't fuse inside computations that are already fused. -- 637657345 by A. Unique TensorFlower: Automated rollback of changelist 636208997. 637651034 by A. Unique TensorFlower: Integrate LLVM at llvm/llvm-project@fddf350f9640 Updates LLVM usage to match [fddf350f9640](https://github.com/llvm/llvm-project/commit/fddf350f9640) -- 637639233 by A. Unique TensorFlower: PR #12940: [ROCm] Fix dot_bf16.hlo.test on ROCm Imported from GitHub PR https://github.com/openxla/xla/pull/12940 Added additional params for `hlo_lit_tests` as a workaround, so `mi200.txtpb` would be used in `dot_bf16.hlo.test` for rocm.
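A note on the GatherV2 guard in patch 001 above: the check failure stems from the `axis + 1` computation visible in the diff context. The following minimal, self-contained C++ sketch shows the mechanism; it is not TensorFlow code, and `ValidateGatherAxis` is a hypothetical stand-in for the validation sequence in gather_op.cc.

#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical mirror of the axis validation in gather_op.cc. Without the
// kint64max guard, `axis + 1` overflows for axis == kint64max (signed
// overflow; it wraps to a negative value on common platforms), so
// min_params_dim goes negative, the rank check passes vacuously, and the
// kernel check-fails later instead of returning InvalidArgument.
bool ValidateGatherAxis(int64_t axis, int64_t params_rank) {
  const int64_t kint64max = std::numeric_limits<int64_t>::max();
  if (axis >= kint64max) return false;  // the guard added by the patch
  const int64_t min_params_dim = axis < 0 ? -axis : axis + 1;
  return params_rank >= min_params_dim;
}

int main() {
  // Rejected cleanly rather than check-failing:
  std::cout << ValidateGatherAxis(std::numeric_limits<int64_t>::max(), 2)
            << "\n";  // prints 0
  // A normal axis still validates:
  std::cout << ValidateGatherAxis(1, 2) << "\n";  // prints 1
  return 0;
}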
Copybara import of the project: -- c3bb3a7349266a51ff22a2e18dab0afb6e81bad4 by mmakevic : Have dot_bf16.hlo.test use mi200.txtpb for rocm Merging this change closes #12940 -- 637632492 by A. Unique TensorFlower: PR #13089: Fix reduce_large_row_to_scalar.hlo.test Imported from GitHub PR https://github.com/openxla/xla/pull/13089 Copybara import of the project: -- ae97058c01ca57107a2566a6f190d51f5ad4ca0e by mmakevic : Fix reduce_large_row_to_scalar.hlo.test Merging this change closes #13089 -- 637623329 by A. Unique TensorFlower: Automated rollback of changelist 637594837. 637607386 by A. Unique TensorFlower: Automated rollback of changelist 636926669. 637594837 by A. Unique TensorFlower: [XLA:GPU] Pass CUDA_VERSION explicitly into CudnnFusedConvRewriter. Passing the CuDNN version will be the next step. -- 637580666 by A. Unique TensorFlower: Remove usage of --xla_gpu_enable_triton_hopper in autotuner -- 637578573 by A. Unique TensorFlower: [XLA:GPU] Add documentation about RTVars. -- 637570959 by A. Unique TensorFlower: Update GraphDef version to 1875. -- 637570942 by A. Unique TensorFlower: compat: Update forward compatibility horizon to 2024-05-27 -- 637561798 by A. Unique TensorFlower: PR #12979: [NVIDIA] Fix PGLE for latency estimation of p2p instructions Imported from GitHub PR https://github.com/openxla/xla/pull/12979 PGLE doesn't recognize p2p instructions such as send or recv as async operations. This adds a utility to check whether an instruction is a p2p communication instruction. Copybara import of the project: -- 469b2d31ff6b0270dda28f8754462681514d0e04 by TJ Xu : fix pgle not recognizing p2p instructions Merging this change closes #12979 -- 637560035 by A. Unique TensorFlower: [xla:gpu] Track loop iteration counter of a WhileThunk in a thread-local variable -- 637552495 by A. Unique TensorFlower: PR #13056: Use `operator->` with XLA FFI Result Buffers in custom call docs Imported from GitHub PR https://github.com/openxla/xla/pull/13056 Copybara import of the project: -- 7940a1a02a0f93736a88406958edf62488bdbe19 by Andrey Portnoy : Use `operator->` with XLA FFI Result Buffers in custom call docs Merging this change closes #13056 -- 637547404 by A. Unique TensorFlower: PR #13068: Introduce the Blackwell compute capability. Imported from GitHub PR https://github.com/openxla/xla/pull/13068 Introduce the Blackwell compute capability. Future Blackwell-specific changes can be guarded by this capability. Copybara import of the project: -- cc1adebc95166b2d3979cc01de954a1895515ad4 by Dimitris Vardoulakis : Introduce the Blackwell compute capability. Future Blackwell-specific changes can be guarded by this capability. Merging this change closes #13068 -- 637541058 by A. Unique TensorFlower: PR #13061: Add Triton support for XLA clamp Imported from GitHub PR https://github.com/openxla/xla/pull/13061 Add Triton support for XLA clamp instruction. Clamp is a common instruction found in FP8 fusions, and will be used in cuDNN fusions. This is a fix for a previously rolled-back PR due to an internal ir_emitter_triton test failure: https://github.com/openxla/xla/commit/d114eceb0afa4289e1ba4468a0474d2c1ffe4123 cc @sergeykozub @sergachev Copybara import of the project: -- 3496ba2fa86571ab290e0881dd06400c415d80b6 by Elfie Guo : Add Triton support for XLA clamp. Merging this change closes #13061 -- 637366630 by A. Unique TensorFlower: Update GraphDef version to 1874. -- 637366295 by A. Unique TensorFlower: compat: Update forward compatibility horizon to 2024-05-26 -- 637185396 by A.
Unique TensorFlower: Automated Code Change -- 637168744 by A. Unique TensorFlower: Update GraphDef version to 1873. -- 637168421 by A. Unique TensorFlower: compat: Update forward compatibility horizon to 2024-05-25 -- 637166714 by A. Unique TensorFlower: Attempt loading libOpenCL.so before libOpenCL-pixel.so -- 637137789 by A. Unique TensorFlower: feat: Implement hermetic Python version matching system Python version -- 637102058 by A. Unique TensorFlower: [IFRT] Add xla::ifrt::Sharding::IsFullyReplicated() IFRT Sharding type gains `IsFullyReplicated()`, which quickly tells if the sharding represents a fully-replicated sharding. The main motivation is to make full replication information queryable at IFRT shardings and to prepare for enabling IFRT implementations to handle full replication directly. There is a preset of rules: * `SingleDeviceSharding` is trivially fully replicated by its definition. * `ConcreteSharding` and `OpaqueSharding` are not fully replicated. There are special cases where they may be fully replicated, but the user is advised to use a more specific sharding type to represent such cases. * `ConcreteEvenSharding` may or may not be fully replicated. This is controlled at creation time. * `ShardingParamSharding` and (IFRT) `HloSharding` depend on whether their lower-level sharding represents full replication. `ConcreteEvenSharding` is a noteworthy case where the full replication information does not come from the existing source of the information. This is because the creators of this sharding (e.g., JAX) typically have the information, but the replication information is lost when coercing it into `ConcreteEvenSharding`. This will gradually become less problematic once JAX uses a higher-level IFRT sharding type (mainly (IFRT) `HloSharding`) in more places. This change extends the `Sharding` type, but the new method is not used by any existing code. -- 637097325 by A. Unique TensorFlower: Ensure delegates properly delegate models -- 637080761 by A. Unique TensorFlower: Add barrier logs. -- 637070664 by A. Unique TensorFlower: Clean up include and build file -- 637069670 by A. Unique TensorFlower: Use the `LoadedClientGraph`'s copy of `FunctionLibraryDefinition` instead of getting it from the `FallbackState` in the parent `GraphExecutor` -- 637069442 by A. Unique TensorFlower: update doc ref -- 637061122 by A. Unique TensorFlower: Refactor exhaustive testing of unary float32 functions into a library. -- 637046941 by A. Unique TensorFlower: fix profile_util's compatible_with tag typo -- 637028365 by A. Unique TensorFlower: [XLA] Refactor HostOffloader. Change HostOffloader's algorithm for identifying host memory offloading. This approach supports every conceivable host memory offloading pattern (as of today). -- 637023690 by A. Unique TensorFlower: Simplify volumes for docker container in XLA build script -- 637018892 by A. Unique TensorFlower: move flatbuffer_compatibility_test target to tflite compiler -- 637008187 by A. Unique TensorFlower: Add copyright notice to profiler_utils.cc -- 636990162 by A. Unique TensorFlower: Adds a proto profile summary formatter to the TFLite benchmark. Adds a Python script to convert benchmark profile protos to a JSON consumable by the model-explorer. -- 636976463 by A.
Unique TensorFlower: Add profiler_util to enable flexible TPU profiler registration for different purposes -- PiperOrigin-RevId: 637889039 --- tensorflow/compiler/mlir/lite/schema/BUILD | 23 + .../schema/flatbuffer_compatibility_test.cc | 5 +- .../compiler/mlir/lite/schema/schema_v3b.fbs | 1242 +++++++++++ tensorflow/core/kernels/BUILD | 1 - tensorflow/core/kernels/gather_nd_op.cc | 3 +- tensorflow/core/kernels/gather_nd_op.h | 7 +- tensorflow/core/kernels/scatter_nd_op.cc | 63 +- tensorflow/core/kernels/scatter_nd_op.h | 4 +- .../core/kernels/scatter_nd_op_cpu_impl.h | 52 +- .../core/kernels/scatter_nd_op_gpu.cu.cc | 10 +- tensorflow/core/ops/uniform_quant_ops.cc | 3 +- tensorflow/core/public/version.h | 2 +- tensorflow/core/tfrt/common/BUILD | 48 +- .../core/tfrt/common/async_value_tensor.cc | 5 + .../core/tfrt/common/async_value_tensor.h | 3 + .../tfrt/common/create_pjrt_client_util.cc | 6 + .../tfrt/common/create_pjrt_client_util.h | 2 +- .../common/create_pjrt_client_util_test.cc | 3 +- tensorflow/core/tfrt/common/global_state.cc | 3 +- .../tfrt/common/pjrt_client_factory_options.h | 2 - .../common/pjrt_client_factory_registry.cc | 11 +- .../common/pjrt_client_factory_registry.h | 3 +- .../common/pjrt_cpu_client_registration.cc | 6 +- .../pjrt_cpu_client_registration_test.cc | 4 +- .../common/pjrt_gpu_client_registration.cc | 3 +- .../pjrt_gpu_client_registration_test.cc | 4 +- tensorflow/core/tfrt/common/pjrt_state.cc | 6 + tensorflow/core/tfrt/common/pjrt_state.h | 8 + .../core/tfrt/common/pjrt_state_test.cc | 11 +- tensorflow/core/tfrt/common/pjrt_util.cc | 6 +- tensorflow/core/tfrt/common/pjrt_util.h | 3 +- tensorflow/core/tfrt/common/pjrt_util_test.cc | 4 +- .../tfrt/graph_executor/graph_executor.cc | 3 +- tensorflow/lite/CMakeLists.txt | 7 + .../lite/delegates/gpu/cl/opencl_wrapper.cc | 34 +- .../utils/experimental/stable_delegate/BUILD | 1 + .../stable_delegate/kernel_test_main.cc | 15 +- .../lite/delegates/xnnpack/conv_2d_test.cc | 35 - .../delegates/xnnpack/xnnpack_delegate.cc | 7 - .../lite/delegates/xnnpack/xnnpack_delegate.h | 2 - tensorflow/lite/kernels/test_util.cc | 7 +- tensorflow/lite/kernels/test_util.h | 9 + tensorflow/lite/profiling/BUILD | 3 + .../lite/profiling/profile_summarizer.cc | 2 + .../lite/profiling/profile_summarizer.h | 19 +- .../profiling/profile_summary_formatter.cc | 225 +- .../profiling/profile_summary_formatter.h | 81 +- .../profile_summary_formatter_test.cc | 260 ++- tensorflow/lite/profiling/proto/BUILD | 41 + .../lite/profiling/proto/CMakeLists.txt | 41 + .../lite/profiling/proto/profiling_info.proto | 63 + tensorflow/lite/python/BUILD | 1 + tensorflow/lite/schema/BUILD | 22 - tensorflow/lite/tools/BUILD | 1 + .../lite/tools/benchmark/CMakeLists.txt | 9 + tensorflow/lite/tools/benchmark/README.md | 17 + .../tools/benchmark/benchmark_tflite_model.cc | 74 +- .../tools/benchmark/profiling_listener.cc | 32 +- .../lite/tools/benchmark/profiling_listener.h | 8 +- .../tools/cmake/modules/FindProtobuf.cmake | 16 + .../lite/tools/cmake/modules/protobuf.cmake | 45 + tensorflow/python/compat/compat.py | 2 +- third_party/llvm/generated.patch | 1901 +++++++++++++++-- third_party/llvm/workspace.bzl | 4 +- third_party/py/python_init_repositories.bzl | 4 +- third_party/py/python_repo.bzl | 148 +- third_party/xla/.kokoro/linux/build.sh | 14 +- third_party/xla/docs/custom_call.md | 2 +- third_party/xla/docs/indexing.md | 238 ++- .../py/python_init_repositories.bzl | 4 +- .../xla/third_party/py/python_repo.bzl | 148 +- .../py/python_init_repositories.bzl | 4 +-
.../tsl/third_party/py/python_repo.bzl | 148 +- third_party/xla/xla/debug_options_flags.cc | 13 +- third_party/xla/xla/pjrt/cpu/BUILD | 4 + third_party/xla/xla/pjrt/cpu/cpu_client.cc | 87 +- third_party/xla/xla/python/BUILD | 20 +- third_party/xla/xla/python/ifrt/sharding.cc | 43 +- third_party/xla/xla/python/ifrt/sharding.h | 32 +- .../xla/xla/python/ifrt/sharding_serdes.cc | 11 +- .../xla/xla/python/ifrt/sharding_serdes.proto | 1 + .../xla/python/ifrt/sharding_serdes_test.cc | 9 +- .../xla/xla/python/ifrt/sharding_test.cc | 89 +- third_party/xla/xla/python/pjrt_ifrt/BUILD | 7 + .../xla/xla/python/pjrt_ifrt/xla_sharding.cc | 18 + .../xla/xla/python/pjrt_ifrt/xla_sharding.h | 11 +- .../xla/python/pjrt_ifrt/xla_sharding_test.cc | 35 + third_party/xla/xla/python/profiler.cc | 24 +- third_party/xla/xla/python/profiler_utils.cc | 56 + third_party/xla/xla/python/profiler_utils.h | 27 + third_party/xla/xla/service/BUILD | 3 + third_party/xla/xla/service/cpu/BUILD | 45 + .../xla/xla/service/cpu/cpu_compiler.cc | 42 +- .../xla/xla/service/cpu/cpu_executable.cc | 51 +- .../xla/xla/service/cpu/cpu_executable.h | 55 +- third_party/xla/xla/service/cpu/runtime/BUILD | 4 + .../xla/xla/service/cpu/runtime/copy_thunk.cc | 9 +- .../xla/xla/service/cpu/runtime/thunk.cc | 25 + .../xla/xla/service/cpu/runtime/thunk.h | 28 +- .../xla/xla/service/cpu/thunk_emitter.cc | 95 + .../xla/xla/service/cpu/thunk_emitter.h | 59 + .../gpu/conv_layout_normalization_test.cc | 32 +- .../xla/service/gpu/cudnn_fusion_compiler.cc | 122 +- .../xla/xla/service/gpu/fusions/cudnn_test.cc | 26 + .../xla/service/gpu/gemm_fusion_autotuner.cc | 5 +- .../xla/xla/service/gpu/gpu_compiler.cc | 56 +- .../xla/xla/service/gpu/instruction_fusion.cc | 27 +- .../xla/xla/service/gpu/instruction_fusion.h | 13 +- .../service/gpu/instruction_fusion_test.cc | 22 + .../xla/xla/service/gpu/ir_emitter_triton.cc | 4 + .../xla/service/gpu/ir_emitter_triton_test.cc | 14 +- third_party/xla/xla/service/gpu/model/BUILD | 2 + .../service/gpu/model/coalescing_analysis.cc | 65 +- .../service/gpu/model/coalescing_analysis.h | 1 + .../gpu/model/coalescing_analysis_test.cc | 58 + .../model/gpu_indexing_performance_model.cc | 8 +- .../gpu/model/indexing_analysis_test.cc | 44 +- .../xla/xla/service/gpu/model/indexing_map.cc | 118 +- .../xla/xla/service/gpu/model/indexing_map.h | 50 +- .../service/gpu/model/indexing_map_test.cc | 34 +- .../service/gpu/model/indexing_test_utils.cc | 2 +- third_party/xla/xla/service/gpu/runtime/BUILD | 2 + .../xla/service/gpu/runtime/while_thunk.cc | 29 +- .../xla/xla/service/gpu/runtime/while_thunk.h | 7 + third_party/xla/xla/service/gpu/tests/BUILD | 6 + .../xla/xla/service/gpu/tests/dot_bf16.hlo | 4 +- .../service/gpu/tests/reduce_atomic_min.hlo | 415 ++-- .../gpu/tests/reduce_large_row_to_scalar.hlo | 510 +++-- .../xla/xla/service/gpu/triton_support.cc | 2 +- third_party/xla/xla/service/host_offloader.cc | 1587 +++++++------- third_party/xla/xla/service/host_offloader.h | 170 +- .../xla/xla/service/host_offloader_test.cc | 1149 ++++++++-- .../xla/service/latency_hiding_scheduler.cc | 8 + .../xla/service/latency_hiding_scheduler.h | 1 + .../profile_guided_latency_estimator.cc | 3 +- .../profile_guided_latency_estimator_test.cc | 56 + .../xla/stream_executor/device_description.h | 3 +- third_party/xla/xla/tests/exhaustive/BUILD | 32 +- .../tests/exhaustive/exhaustive_test_main.cc | 33 + .../exhaustive_unary_test_f32_or_smaller.cc | 46 +- .../coordination/coordination_service.cc | 8 +- third_party/xla/xla/xla.proto | 9 +- 
third_party/xla/xla/xla_data.proto | 4 +- 143 files changed, 8669 insertions(+), 2274 deletions(-) rename tensorflow/{ => compiler/mlir}/lite/schema/flatbuffer_compatibility_test.cc (95%) create mode 100644 tensorflow/compiler/mlir/lite/schema/schema_v3b.fbs create mode 100644 tensorflow/lite/profiling/proto/BUILD create mode 100644 tensorflow/lite/profiling/proto/CMakeLists.txt create mode 100644 tensorflow/lite/profiling/proto/profiling_info.proto create mode 100644 tensorflow/lite/tools/cmake/modules/FindProtobuf.cmake create mode 100644 tensorflow/lite/tools/cmake/modules/protobuf.cmake create mode 100644 third_party/xla/xla/python/profiler_utils.cc create mode 100644 third_party/xla/xla/python/profiler_utils.h create mode 100644 third_party/xla/xla/service/cpu/thunk_emitter.cc create mode 100644 third_party/xla/xla/service/cpu/thunk_emitter.h create mode 100644 third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc diff --git a/tensorflow/compiler/mlir/lite/schema/BUILD b/tensorflow/compiler/mlir/lite/schema/BUILD index 17a6bdb636959d..7cbc2253a83821 100644 --- a/tensorflow/compiler/mlir/lite/schema/BUILD +++ b/tensorflow/compiler/mlir/lite/schema/BUILD @@ -1,4 +1,5 @@ load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( @@ -69,3 +70,25 @@ cc_library( "@flatbuffers", ], ) + +# Schema test to make sure we don't introduce backward incompatible changes +# to schemas. +tf_cc_test( + name = "flatbuffer_compatibility_test", + size = "small", + srcs = ["flatbuffer_compatibility_test.cc"], + data = [ + "schema.fbs", + "schema_v3b.fbs", + ], + tags = [ + "no_oss", + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + "//tensorflow/core/platform", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:flatc_library", + ], +) diff --git a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/compiler/mlir/lite/schema/flatbuffer_compatibility_test.cc similarity index 95% rename from tensorflow/lite/schema/flatbuffer_compatibility_test.cc rename to tensorflow/compiler/mlir/lite/schema/flatbuffer_compatibility_test.cc index 976c2b302c1a6e..c2eea199bc6401 100644 --- a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc +++ b/tensorflow/compiler/mlir/lite/schema/flatbuffer_compatibility_test.cc @@ -63,9 +63,10 @@ TEST(SchemaTest, TestCompatibility) { // Read file contents of schemas into strings // TODO(aselle): Need a reliable way to load files. std::string base_contents, current_contents; - const char *base_filename = TFLITE_TF_PREFIX "lite/schema/schema_v3b.fbs"; + const char *base_filename = TFLITE_TF_PREFIX + "compiler/mlir/lite/schema/schema_v3b.fbs"; const char *current_filename = - TFLITE_TF_PREFIX "lite/schema/schema.fbs"; + TFLITE_TF_PREFIX "compiler/mlir/lite/schema/schema.fbs"; ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents)); ASSERT_TRUE(LoadFileRaw(current_filename, ¤t_contents)); diff --git a/tensorflow/compiler/mlir/lite/schema/schema_v3b.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3b.fbs new file mode 100644 index 00000000000000..917786050f7e8b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/schema/schema_v3b.fbs @@ -0,0 +1,1242 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. +// Version 3b: Rename fields in SignatureDef. Has backward compatibility with +// version 3 and 3a. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, + COMPLEX128 = 11, + UINT64 = 12, + // Experimental: Resource and variant types are experimental, that are subject + // to change. Do not implement custom kernels using resource & variant types + // now. + RESOURCE = 13, + VARIANT = 14, + UINT32 = 15, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. 
In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. +enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). 
+ traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +// LINT.IfChange +enum BuiltinOperator : int32 { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. 
+ // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + CUMSUM = 128, + CALL_ONCE = 129, + BROADCAST_TO = 130, + RFFT2D = 131, + CONV_3D = 132, + IMAG=133, + REAL=134, + COMPLEX_ABS=135, + HASHTABLE = 136, + HASHTABLE_FIND = 137, + HASHTABLE_IMPORT = 138, + HASHTABLE_SIZE = 139, + REDUCE_ALL = 140, + CONV_3D_TRANSPOSE = 141, + VAR_HANDLE = 142, + READ_VARIABLE = 143, + ASSIGN_VARIABLE = 144, +} +// LINT.ThenChange(nnapi_linter/linter.proto) + +// Options for the builtin operators. 
+union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions, + CumsumOptions, + CallOnceOptions, + BroadcastToOptions, + Rfft2dOptions, + Conv3DOptions, + HashtableOptions, + HashtableFindOptions, + HashtableImportOptions, + HashtableSizeOptions, + VarHandleOptions, + ReadVariableOptions, + AssignVariableOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +// Options for both Conv3D and Conv3DTranspose. +table Conv3DOptions { + padding:Padding; + stride_d:int; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_d_factor:int = 1; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. 
+ depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 3. + pot_scale_int16:bool = true; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + // This field is currently ignored in the L2 Norm Op. + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. 
+ asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; + // Parameters for Gather version 5 or above. 
+ batch_dims: int = 0; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. + SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table CallOnceOptions { + init_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; + // Parameters for BatchMatMul version 4 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table CumsumOptions { + exclusive:bool; + reverse:bool; +} + +table BroadcastToOptions { +} + +table Rfft2dOptions { +} + +table HashtableOptions { + // The identity of hash tables. 
This identity will be used across different + // subgraphs in the same interpreter instance. + table_id:int; + key_dtype:TensorType; + value_dtype:TensorType; +} + +table HashtableFindOptions { +} + +table HashtableImportOptions { +} + +table HashtableSizeOptions { +} + +table VarHandleOptions { + container:string; + shared_name:string; +} + +table ReadVariableOptions { +} + +table AssignVariableOptions { +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + // This field is for backward compatibility. This field will be used when + // the value of the extended builtin_code field has less than + // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + deprecated_builtin_code:byte; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; + + // This field is introduced for resolving op builtin code shortage problem + // (the original BuiltinOperator enum field was represented as a byte). + // This field will be used when the value of the extended builtin_code field + // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + builtin_code:BuiltinOperator; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. 
+ operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +// Map from an alias name of tensor to tensor index in the graph. +// This is used in Signature def. +table TensorMap { + // Represents the alias to use for this tensor. + name:string; + + // The actual tensor index in the primary graph, that 'name' corresponds to. + tensor_index:uint; +} + +// This corresponds to SignatureDef in Tensorflow SavedModel. +// The SignatureDef will be part of the SavedModel provided for conversion. +table SignatureDef { + // Named inputs for this signature. + inputs:[TensorMap]; + + // Named outputs for this signature. + outputs:[TensorMap]; + + // Key value which was in the Tensorflow SavedModel SignatureDef map. + signature_key:string; + + // Model tag, deprecated. + deprecated_tag:string (deprecated); + + // Index of subgraphs that corresponds to the exported method. + subgraph_index:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; + + // Optional SignatureDefs for the model. 
+ signature_defs:[SignatureDef]; +} + +root_type Model; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 69f763abfda5a7..3790f64e0cec68 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -65,7 +65,6 @@ package_group( packages = [ "//tensorflow/...", "//tensorflow_text/...", - "//waymo/ml/compiler/frontend/kernels/...", "//waymo/onboard/ml/...", ], ) diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc index 9551bdd79d4ae5..c133556b4aaa43 100644 --- a/tensorflow/core/kernels/gather_nd_op.cc +++ b/tensorflow/core/kernels/gather_nd_op.cc @@ -45,8 +45,7 @@ class GatherNdOp : public OpKernel { Tensor out; OP_REQUIRES_OK( - c, functor::DoGatherNd( - c, params, indices, &out)); + c, functor::DoGatherNd(c, params, indices, &out)); c->set_output(0, out); } }; diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index 6059a2bbdafb31..09bad00c59b070 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -43,8 +43,7 @@ struct GatherNdSlice { typename TTypes::Matrix Tout); }; -template +template Status DoGatherNd(OpKernelContext* c, const Tensor& params, const Tensor& indices, Tensor* out) { if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { @@ -152,10 +151,6 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params, indices_nd); } - if constexpr (kDropBadIndices) { - return absl::OkStatus(); - } - // bad_i will only return >= 0 on CPUs right now. if (bad_i >= 0) { auto shape = indices.shape(); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index ea369fd49a5ea2..0f604b0e605879 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -878,7 +878,7 @@ class IndexFlattener { namespace { template + scatter_nd_op::UpdateOp Op> Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { @@ -925,11 +925,7 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, for (int i = 0; i < IXDIM; ++i) { \ output_shape_prefix[i] = shape.dim_size(i); \ } \ - constexpr bool kShallDropBadIndices = \ - kDropBadIndices || std::is_same::value; \ - functor::ScatterNdFunctor \ - functor; \ + functor::ScatterNdFunctor functor; \ bad_i = \ functor(c->eigen_device(), slice_size, output_shape_prefix, \ output_matrix, indices_flat, updates_flat, output_matrix); \ @@ -951,9 +947,6 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, slice_dim); } } - if constexpr (kDropBadIndices) { - return absl::OkStatus(); - } if (bad_i >= 0) { auto slice_shape = indices.shape(); slice_shape.RemoveLastDims(1); @@ -977,8 +970,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, // back to GPU. This is useful because the CPU implementation is deterministic // and the GPU implementation is not. Tensor inputs to this function must be on // the GPU. -template +template Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { @@ -1023,7 +1015,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, } TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - TF_RETURN_IF_ERROR(DoScatterNd( + TF_RETURN_IF_ERROR(DoScatterNd( c, host_indices, host_updates, shape, &host_out, /*allocate=*/false)); // Copy 'host_out' to device. 
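The context lines above note that this host round trip exists because "the CPU implementation is deterministic and the GPU implementation is not." The underlying reason is that the GPU scatter kernel accumulates duplicate indices with atomic adds whose ordering varies from run to run, and floating-point addition is not associative. A standalone C++ illustration of that non-associativity (not TensorFlow code):

#include <cstdio>

int main() {
  // Summing the same three floats in two different orders gives two
  // different results, because float addition is not associative. An
  // atomics-based scatter fixes no accumulation order, so duplicate
  // indices can produce run-to-run differences; the CPU fallback above
  // sidesteps this when op determinism is required.
  float a = 1e8f, b = -1e8f, c = 1.0f;
  std::printf("%g\n", (a + b) + c);  // prints 1
  std::printf("%g\n", a + (b + c));  // prints 0: 1.0f is absorbed into -1e8f
  return 0;
}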
@@ -1041,15 +1033,15 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, } // namespace template + scatter_nd_op::UpdateOp Op> Status DoScatterNd(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value && tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) { - return DoScatterNdOnCpu( - c, indices, updates, shape, out, allocate); + return DoScatterNdOnCpu(c, indices, updates, shape, out, + allocate); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -1057,11 +1049,11 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, // atomics, which are not supported for all integer types. if constexpr (std::is_same::value && std::is_integral::value) { - return DoScatterNdOnCpu( - c, indices, updates, shape, out, allocate); + return DoScatterNdOnCpu(c, indices, updates, shape, out, + allocate); } else { - return DoScatterNdImpl( - c, indices, updates, shape, out, allocate); + return DoScatterNdImpl(c, indices, updates, shape, + out, allocate); } } } // namespace functor @@ -1069,29 +1061,16 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ - template <> \ - Index \ - ScatterNdFunctor:: \ - operator()(const GPUDevice& d, const Index slice_size, \ - const Eigen::array output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput); \ - extern template struct ScatterNdFunctor; \ - template <> \ - Index ScatterNdFunctor:: \ - operator()(const GPUDevice& d, const Index slice_size, \ - const Eigen::array output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput); \ - extern template struct ScatterNdFunctor; +#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ + template <> \ + Index ScatterNdFunctor::operator()( \ + const GPUDevice& d, const Index slice_size, \ + const Eigen::array output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput); \ + extern template struct ScatterNdFunctor; #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \ diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h index 8d2e74b18ca864..f9a2ce0ed6e12b 100644 --- a/tensorflow/core/kernels/scatter_nd_op.h +++ b/tensorflow/core/kernels/scatter_nd_op.h @@ -44,7 +44,7 @@ namespace functor { // Functor used by ScatterOp to do the computations. template + scatter_nd_op::UpdateOp op, int IXDIM> struct ScatterNdFunctor { // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. Index operator()( @@ -63,7 +63,7 @@ struct ScatterNdFunctor { // right type (T) and shape. This tensor will not be zeroed out // before the scatter is executed. 
template + scatter_nd_op::UpdateOp Op> Status DoScatterNd(OpKernelContext* c, const Tensor& indices, const Tensor& updates, const TensorShape& shape, Tensor* out, bool allocate); diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index abdbc1ece968bf..b0123780cc6406 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -103,9 +103,8 @@ class UpdateExecutor { namespace functor { // Implementation of update functor for CPU. -template -struct ScatterNdFunctor { +template +struct ScatterNdFunctor { Index operator()( const CPUDevice& d, const Index slice_size, const Eigen::array output_shape_prefix, @@ -137,44 +136,33 @@ struct ScatterNdFunctor { i += ix_d * batch_strides[dim]; } if (TF_PREDICT_FALSE(out_of_bounds)) { - if constexpr (kDropBadIndices) { - continue; - } error_loc = loc; break; + } else { + auto input_chip = Toutput.template chip<0>(i); + auto output_chip = input_chip; + auto update_chip = Tupdates.template chip<0>(loc); + update_executor::UpdateExecutor< + CPUDevice, decltype(input_chip), decltype(update_chip), + decltype(output_chip), OP>::Execute(d, input_chip, update_chip, + output_chip); } - auto input_chip = Toutput.template chip<0>(i); - auto output_chip = input_chip; - auto update_chip = Tupdates.template chip<0>(loc); - update_executor::UpdateExecutor< - CPUDevice, decltype(input_chip), decltype(update_chip), - decltype(output_chip), OP>::Execute(d, input_chip, update_chip, - output_chip); } return error_loc; } }; -#define REGISTER_SCATTER_ND_FULL(T, Index, op) \ - template Index ScatterNdFunctor:: \ - operator()(const CPUDevice& d, const Index slice_size, \ - const Eigen::array \ - output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput); \ - template Index ScatterNdFunctor:: \ - operator()(const CPUDevice& d, const Index slice_size, \ - const Eigen::array \ - output_shape_prefix, \ - typename TTypes::Tensor Tparams, \ - typename TTypes::ConstTensor Tindices, \ - typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput) +#define REGISTER_SCATTER_ND_FULL(T, Index, op) \ + template Index \ + ScatterNdFunctor::operator()( \ + const CPUDevice& d, const Index slice_size, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput) #define REGISTER_SCATTER_ND_INDEX(type, op) \ REGISTER_SCATTER_ND_FULL(type, int32, op); \ diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc index 4e528c58e6ba0f..fd1d4747c40982 100644 --- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc @@ -124,9 +124,8 @@ __global__ void ScatterNdOpKernel( namespace functor { // Functor used by ScatterOp to do the computations. 
-template -struct ScatterNdFunctor { +template +struct ScatterNdFunctor { Index operator()( const GPUDevice& d, const Index slice_size, const Eigen::array output_shape_prefix, @@ -165,9 +164,8 @@ struct ScatterNdFunctor { } // namespace functor -#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ - template struct functor::ScatterNdFunctor; +#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \ + template struct functor::ScatterNdFunctor; #define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \ DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \ diff --git a/tensorflow/core/ops/uniform_quant_ops.cc b/tensorflow/core/ops/uniform_quant_ops.cc index 514c9f9278d8c5..c5fcb762dabd13 100644 --- a/tensorflow/core/ops/uniform_quant_ops.cc +++ b/tensorflow/core/ops/uniform_quant_ops.cc @@ -29,7 +29,8 @@ using tensorflow::errors::Unknown; // If the rank and all dim sizes are known, return corresponding TensorShape. // Otherwise return Unknown error. -StatusOr ToTensorShape(ShapeHandle shape_handle, int64_t rank) { +absl::StatusOr ToTensorShape(ShapeHandle shape_handle, + int64_t rank) { TensorShape shape; for (int i = 0; i < rank; ++i) { int64_t dim_size = shape_inference::InferenceContext::Value( diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 0ab15886f47593..b188b0142e52a7 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1872 // Updated: 2024/5/24 +#define TF_GRAPH_DEF_VERSION 1876 // Updated: 2024/5/28 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index 8129cdb0ea0f65..ac5a88c0d326f6 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -46,8 +46,6 @@ cc_library( visibility = [":friends"], deps = [ "//tensorflow/core:framework", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", "@local_xla//xla/pjrt:utils", "@tf_runtime//:hostcontext", ], @@ -64,6 +62,8 @@ cc_library( visibility = [":friends"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "@com_google_absl//absl/log:check", "@local_xla//xla/pjrt:pjrt_client", "@tf_runtime//:hostcontext", "@tf_runtime//:support", @@ -96,9 +96,15 @@ cc_library( ":pjrt_client_factory_options", ":pjrt_client_factory_registry", "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/framework:resource_base", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/client:local_client", "@local_xla//xla/pjrt:local_device_state", "@local_xla//xla/pjrt:pjrt_client", @@ -121,11 +127,14 @@ cc_library( deps = [ ":global_state", ":pjrt_state", + "//tensorflow/core:framework", "//tensorflow/core:framework_types_hdr", "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:refcount", "//tensorflow/core/platform:status", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", "@local_xla//xla/pjrt:pjrt_client", ], ) @@ -144,7 +153,12 @@ cc_library( deps = [ ":global_state", 
":pjrt_state", + "//tensorflow/core:framework", "//tensorflow/core:framework_types_hdr", + "//tensorflow/core/platform:refcount", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", "@local_xla//xla/pjrt:pjrt_client", ], ) @@ -153,16 +167,17 @@ tf_cc_test( name = "pjrt_state_test", srcs = ["pjrt_state_test.cc"], deps = [ - ":global_state", + ":pjrt_cpu_client_registration", ":pjrt_state", "//tensorflow/core:framework_types_hdr", "//tensorflow/core:test", - "//tensorflow/core/platform:status_matchers", + "//tensorflow/core/platform:refcount", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", - "//tensorflow/core/tfrt/common:pjrt_cpu_client_registration", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@local_xla//xla/pjrt/cpu:cpu_client", ], ) @@ -170,7 +185,6 @@ tf_cc_test( name = "pjrt_util_test", srcs = ["pjrt_util_test.cc"], deps = [ - ":global_state", ":pjrt_state", ":pjrt_util", "//tensorflow/core:framework", @@ -180,7 +194,7 @@ tf_cc_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@local_xla//xla/pjrt/cpu:cpu_client", ], ) @@ -199,6 +213,7 @@ tf_cuda_cc_test( ":pjrt_state", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", + "@com_google_absl//absl/strings", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", @@ -223,9 +238,12 @@ cc_library( ":pjrt_client_factory_options", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@local_tsl//tsl/framework:device_type", - "@local_tsl//tsl/platform:statusor", - "@local_xla//xla:statusor", + "@local_tsl//tsl/platform:errors", "@local_xla//xla/pjrt:pjrt_client", ], ) @@ -237,9 +255,10 @@ cc_library( ":pjrt_client_factory_options", ":pjrt_client_factory_registry", "//tensorflow/core:framework_types_hdr", - "@local_xla//xla:statusor", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", - "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", + "@local_xla//xla/pjrt/cpu:cpu_client", ], alwayslink = True, ) @@ -253,6 +272,7 @@ tf_cc_test( ":pjrt_cpu_client_registration", "//tensorflow/core:framework_types_hdr", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) @@ -264,7 +284,8 @@ cc_library( ":pjrt_client_factory_registry", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework_types_hdr", - "@local_xla//xla:statusor", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client", ], @@ -296,6 +317,7 @@ tf_cuda_cc_test( ":pjrt_gpu_client_registration", "//tensorflow/core:framework_types_hdr", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla/service:gpu_plugin", ], ) diff --git a/tensorflow/core/tfrt/common/async_value_tensor.cc b/tensorflow/core/tfrt/common/async_value_tensor.cc index d78c41051d29b8..09b86690157ff0 100644 --- 
a/tensorflow/core/tfrt/common/async_value_tensor.cc +++ b/tensorflow/core/tfrt/common/async_value_tensor.cc @@ -14,11 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/common/async_value_tensor.h" +#include +#include #include #include +#include "absl/log/check.h" #include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/tensor.h" #include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/core/tfrt/common/async_value_tensor.h b/tensorflow/core/tfrt/common/async_value_tensor.h index 25ce153b516298..06e99f8f7bcc48 100644 --- a/tensorflow/core/tfrt/common/async_value_tensor.h +++ b/tensorflow/core/tfrt/common/async_value_tensor.h @@ -15,10 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_COMMON_ASYNC_VALUE_TENSOR_H_ #define TENSORFLOW_CORE_TFRT_COMMON_ASYNC_VALUE_TENSOR_H_ +#include #include #include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/types.h" #include "tfrt/support/forward_decls.h" // from @tf_runtime #include "tfrt/support/ref_count.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/common/create_pjrt_client_util.cc b/tensorflow/core/tfrt/common/create_pjrt_client_util.cc index 73f7dfc6de0e3e..b611b183de9032 100644 --- a/tensorflow/core/tfrt/common/create_pjrt_client_util.cc +++ b/tensorflow/core/tfrt/common/create_pjrt_client_util.cc @@ -17,9 +17,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/common/pjrt_state.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/common/create_pjrt_client_util.h b/tensorflow/core/tfrt/common/create_pjrt_client_util.h index 945cea4efd4098..fe8dfbb8db5f23 100644 --- a/tensorflow/core/tfrt/common/create_pjrt_client_util.h +++ b/tensorflow/core/tfrt/common/create_pjrt_client_util.h @@ -15,10 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_COMMON_CREATE_PJRT_CLIENT_UTIL_H_ #define TENSORFLOW_CORE_TFRT_COMMON_CREATE_PJRT_CLIENT_UTIL_H_ -#include #include #include +#include "absl/status/statusor.h" #include "xla/pjrt/pjrt_client.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/tfrt/common/create_pjrt_client_util_test.cc b/tensorflow/core/tfrt/common/create_pjrt_client_util_test.cc index 4eab11a48c411b..027bf7bed783aa 100644 --- a/tensorflow/core/tfrt/common/create_pjrt_client_util_test.cc +++ b/tensorflow/core/tfrt/common/create_pjrt_client_util_test.cc @@ -14,7 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" // IWYU pragma: keep #include "tensorflow/core/framework/types.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/tfrt/common/global_state.cc b/tensorflow/core/tfrt/common/global_state.cc index 75d15d010234bf..61279217c06325 100644 --- a/tensorflow/core/tfrt/common/global_state.cc +++ b/tensorflow/core/tfrt/common/global_state.cc @@ -17,9 +17,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "xla/pjrt/utils.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/host_allocator.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/common/pjrt_client_factory_options.h b/tensorflow/core/tfrt/common/pjrt_client_factory_options.h index 47caf2116af6b7..70e3092c2df654 100644 --- a/tensorflow/core/tfrt/common/pjrt_client_factory_options.h +++ b/tensorflow/core/tfrt/common/pjrt_client_factory_options.h @@ -15,8 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_OPTIONS_H_ #define TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_OPTIONS_H_ -#include -#include #include #include #include diff --git a/tensorflow/core/tfrt/common/pjrt_client_factory_registry.cc b/tensorflow/core/tfrt/common/pjrt_client_factory_registry.cc index d792a9b2f6b5e6..bea5b42e7b4c20 100644 --- a/tensorflow/core/tfrt/common/pjrt_client_factory_registry.cc +++ b/tensorflow/core/tfrt/common/pjrt_client_factory_registry.cc @@ -16,9 +16,16 @@ limitations under the License. #include #include -#include -#include "tsl/platform/statusor.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" +#include "tsl/framework/device_type.h" +#include "tsl/platform/errors.h" namespace xla { PjrtClientFactoryRegistry& PjrtClientFactoryRegistry::Get() { diff --git a/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h b/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h index 2950772b1ea6f2..01568d11ec1b51 100644 --- a/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h +++ b/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h @@ -19,8 +19,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/registration/registration.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" diff --git a/tensorflow/core/tfrt/common/pjrt_cpu_client_registration.cc b/tensorflow/core/tfrt/common/pjrt_cpu_client_registration.cc index b114821d2f20ec..75bfa24a6b6ad3 100644 --- a/tensorflow/core/tfrt/common/pjrt_cpu_client_registration.cc +++ b/tensorflow/core/tfrt/common/pjrt_cpu_client_registration.cc @@ -16,11 +16,13 @@ limitations under the License. 
#include #include -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "xla/statusor.h" +#include "absl/status/statusor.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/pjrt_client.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_registry.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/tensorflow/core/tfrt/common/pjrt_cpu_client_registration_test.cc b/tensorflow/core/tfrt/common/pjrt_cpu_client_registration_test.cc index 26d6884e91006c..773d1223507038 100644 --- a/tensorflow/core/tfrt/common/pjrt_cpu_client_registration_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_cpu_client_registration_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_registry.h" +#include "tsl/framework/device_type.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc index ead40c6f39c254..99b1fab73f6052 100644 --- a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc +++ b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration.cc @@ -16,13 +16,14 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/statusor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_registry.h" +#include "tsl/platform/statusor.h" namespace xla { absl::StatusOr> GetGpuClient( diff --git a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration_test.cc b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration_test.cc index f4feb34b541cb7..2eeca7a71eca12 100644 --- a/tensorflow/core/tfrt/common/pjrt_gpu_client_registration_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_gpu_client_registration_test.cc @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include #include #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_registry.h" +#include "tsl/framework/device_type.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/tensorflow/core/tfrt/common/pjrt_state.cc b/tensorflow/core/tfrt/common/pjrt_state.cc index a1a8e2366c6a38..12a8937d389c9a 100644 --- a/tensorflow/core/tfrt/common/pjrt_state.cc +++ b/tensorflow/core/tfrt/common/pjrt_state.cc @@ -18,11 +18,17 @@ limitations under the License. 
#include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/tf_pjrt_client.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" #include "tensorflow/core/tfrt/common/pjrt_client_factory_registry.h" +#include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/common/pjrt_state.h b/tensorflow/core/tfrt/common/pjrt_state.h index 180163376b4cd2..4863fc9e7d7e0c 100644 --- a/tensorflow/core/tfrt/common/pjrt_state.h +++ b/tensorflow/core/tfrt/common/pjrt_state.h @@ -17,14 +17,22 @@ limitations under the License. #include #include +#include #include +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/client/local_client.h" #include "xla/pjrt/local_device_state.h" #include "xla/pjrt/pjrt_client.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "tensorflow/core/framework/resource_base.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/framework/allocator.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc index 0b8cf6e1b9bbf8..fddd72ea050509 100644 --- a/tensorflow/core/tfrt/common/pjrt_state_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc @@ -19,19 +19,20 @@ limitations under the License. #include #include +#include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status_matchers.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace { using tensorflow::PjRtState; using ::testing::HasSubstr; - -using ::tensorflow::testing::StatusIs; +using ::tsl::testing::StatusIs; class PjRtStateTestFixture : public testing::Test { protected: diff --git a/tensorflow/core/tfrt/common/pjrt_util.cc b/tensorflow/core/tfrt/common/pjrt_util.cc index 643632d5706a47..54ed3060adbc08 100644 --- a/tensorflow/core/tfrt/common/pjrt_util.cc +++ b/tensorflow/core/tfrt/common/pjrt_util.cc @@ -15,17 +15,19 @@ limitations under the License. 
#include "tensorflow/core/tfrt/common/pjrt_util.h" #include -#include -#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/common/pjrt_state.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/common/pjrt_util.h b/tensorflow/core/tfrt/common/pjrt_util.h index ce9cbc1d11c287..2895f22bf4ea92 100644 --- a/tensorflow/core/tfrt/common/pjrt_util.h +++ b/tensorflow/core/tfrt/common/pjrt_util.h @@ -16,9 +16,8 @@ limitations under the License. #define TENSORFLOW_CORE_TFRT_COMMON_PJRT_UTIL_H_ #include -#include -#include +#include "absl/status/statusor.h" #include "xla/pjrt/pjrt_client.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/tfrt/common/pjrt_util_test.cc b/tensorflow/core/tfrt/common/pjrt_util_test.cc index f8de14dd034812..1361b72c2da686 100644 --- a/tensorflow/core/tfrt/common/pjrt_util_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_util_test.cc @@ -17,10 +17,8 @@ limitations under the License. #include #include -#include "xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "tensorflow/core/framework/resource_mgr.h" +#include "xla/pjrt/cpu/cpu_client.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/common/pjrt_state.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 979590bf83aac7..5580e69e4681cb 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -1130,8 +1130,7 @@ GraphExecutor::LoadedClientGraph::LoadedClientGraph( pflr_(&graph_executor->fallback_state().device_manager(), graph_executor->fallback_state().session_options().env, &graph_executor->fallback_state().session_options().config, - TF_GRAPH_DEF_VERSION, - &graph_executor->fallback_state().func_lib_def(), + TF_GRAPH_DEF_VERSION, &flib_def_, graph_executor->fallback_state() .session_options() .config.graph_options() diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 9f60eb3ac4d235..e290b0967f75dd 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -711,6 +711,13 @@ if(TFLITE_KERNEL_TEST) add_subdirectory(${TFLITE_SOURCE_DIR}/kernels) endif() +# Add the generated headers directory. Required for maintaining the +# tensorflow/lite directory structure for generated headers. +set(TFLITE_GENERATED_HEADERS_DIR ${CMAKE_BINARY_DIR}/tensorflow/lite) + +# Add the profiling proto directory. +add_subdirectory(${TFLITE_SOURCE_DIR}/profiling/proto) + # The benchmark tool. 
add_subdirectory(${TFLITE_SOURCE_DIR}/tools/benchmark) diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index 2419b2c9325ad3..8b4de50df0bd84 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -115,23 +115,6 @@ absl::Status LoadOpenCL() { } #else void* libopencl = nullptr; -#ifdef __ANDROID__ - // Pixel phone or auto? - libopencl = - AndroidDlopenSphalLibrary("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); - if (!libopencl) { - libopencl = - AndroidDlopenSphalLibrary("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); - } - if (libopencl) { - typedef void (*enableOpenCL_t)(); - enableOpenCL_t enableOpenCL = - reinterpret_cast(dlsym(libopencl, "enableOpenCL")); - enableOpenCL(); - LoadOpenCLFunctions(libopencl, true); - return absl::OkStatus(); - } -#endif #ifdef __APPLE__ static const char* kClLibName = "/System/Library/Frameworks/OpenCL.framework/OpenCL"; @@ -140,6 +123,23 @@ absl::Status LoadOpenCL() { #endif #ifdef __ANDROID__ libopencl = AndroidDlopenSphalLibrary(kClLibName, RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + // Legacy Pixel phone or auto path? + libopencl = + AndroidDlopenSphalLibrary("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + libopencl = + AndroidDlopenSphalLibrary("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); + } + if (libopencl) { + typedef void (*enableOpenCL_t)(); + enableOpenCL_t enableOpenCL = + reinterpret_cast(dlsym(libopencl, "enableOpenCL")); + enableOpenCL(); + LoadOpenCLFunctions(libopencl, true); + return absl::OkStatus(); + } + } #else libopencl = dlopen(kClLibName, RTLD_NOW | RTLD_LOCAL); #endif diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD index 9b0cda01bc559a..f37fd78e0f613f 100644 --- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD +++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD @@ -133,6 +133,7 @@ cc_library( srcs = ["kernel_test_main.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/kernels:acceleration_test_util", "//tensorflow/lite/kernels:acceleration_test_util_internal", "//tensorflow/lite/kernels:test_delegate_providers_lib", diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc b/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc index 3c0d4c5a93f2ee..f3fe76d395a79e 100644 --- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc +++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc @@ -15,8 +15,10 @@ limitations under the License. #include #include +#include #include #include "benchmark/benchmark.h" // from @com_google_benchmark +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/kernels/acceleration_test_util.h" #include "tensorflow/lite/kernels/acceleration_test_util_internal.h" #include "tensorflow/lite/kernels/test_delegate_providers.h" @@ -84,7 +86,16 @@ void ValidateAcceleration(const SingleOpModel& model) { GetAccelerationTestParam(test_id) .has_value(); if (!supported) { + // Note that the error `kTfLiteApplicationError` is accepted here. + // We only want to check the delegate is working properly, so an error due + // to incompatibility between the model and the delegate is not considered a + // failure here. 
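The expectation emitted just below implements a deliberately tolerant policy, which is easier to audit as a small predicate. A standalone sketch of the same logic, with plain ints standing in for the TfLiteStatus values:

#include <optional>

// Stand-ins for kTfLiteOk and kTfLiteApplicationError (values illustrative).
constexpr int kOk = 0;
constexpr int kApplicationError = 3;

// For unsupported cases: "delegate never applied" (nullopt) defaults to kOk
// via value_or, and a model/delegate incompatibility is not a failure.
bool UnsupportedCaseAcceptable(std::optional<int> status) {
  const int s = status.value_or(kOk);
  return s == kOk || s == kApplicationError;
}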
+ EXPECT_THAT(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), + testing::AnyOf(kTfLiteOk, kTfLiteApplicationError)); return; + } else { + EXPECT_EQ(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), + kTfLiteOk); } // If we have multiple delegates applied, we would skip this check at the @@ -135,9 +146,7 @@ bool InitKernelTest(int* argc, char** argv) { return true; } -void DestroyKernelTest() { - DelegateTestSuiteAccelerationTestParams::Destroy(); -} +void DestroyKernelTest() { DelegateTestSuiteAccelerationTestParams::Destroy(); } } // namespace } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc index 5654c285c8d150..cab06da2807b8d 100644 --- a/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/conv_2d_test.cc @@ -816,40 +816,5 @@ TEST(Conv2D, TransientIndirectionBuffer) { .Test(xnnpack_delegate.get()); } -TEST(Conv2D, AdaptiveAvxOptimization) { - TfLiteXNNPackDelegateOptions xnnpack_options = - TfLiteXNNPackDelegateOptionsDefault(); - xnnpack_options.num_threads = 2; - xnnpack_options.experimental_adaptive_avx_optimization = true; - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(&xnnpack_options), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto batch_rng = - std::bind(std::uniform_int_distribution(2, 4), std::ref(rng)); - auto input_rng = - std::bind(std::uniform_int_distribution(5, 25), std::ref(rng)); - auto kernel_rng = - std::bind(std::uniform_int_distribution(3, 5), std::ref(rng)); - auto stride_rng = - std::bind(std::uniform_int_distribution(2, 3), std::ref(rng)); - auto channel_rng = - std::bind(std::uniform_int_distribution(2, 16), std::ref(rng)); - - Conv2DTester() - .BatchSize(batch_rng()) - .InputHeight(input_rng()) - .InputWidth(input_rng()) - .InputChannels(channel_rng()) - .OutputChannels(channel_rng()) - .KernelHeight(kernel_rng()) - .KernelWidth(kernel_rng()) - .StrideHeight(stride_rng()) - .StrideWidth(stride_rng()) - .Test(xnnpack_delegate.get()); -} - } // namespace xnnpack } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index ff54dee09a0fb5..33e1d317bce6c8 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -628,10 +628,6 @@ class Delegate { #endif } - bool experimental_adaptive_avx_optimization() const { - return options_.experimental_adaptive_avx_optimization; - } - pthreadpool_t threadpool() const { #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) return nullptr; @@ -1120,9 +1116,6 @@ class Subgraph { if (delegate.transient_indirection_buffer()) { flags |= XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER; } - if (delegate.experimental_adaptive_avx_optimization()) { - xnn_experiment_enable_adaptive_avx_optimization(); - } if (delegate.force_fp16()) { flags |= XNN_FLAG_FORCE_FP16_INFERENCE; } else { diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h index 55eddcf1a54d67..dd5bf1adc4f587 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h @@ -68,8 +68,6 @@ typedef struct { // Deprecated. Use the flags bitfield with the // TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS mask. 
bool handle_variable_ops; - // Enable adaptive optimization for AVX CPUs. - bool experimental_adaptive_avx_optimization; // Path to the weight cache to load if `weight_cache` is undefined. // // WARNING this is an experimental flag. diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index 1dd47692c50819..99ab45a13c71a7 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -278,7 +278,9 @@ TfLiteStatus SingleOpModel::ApplyDelegate() { if (delegate_) { TFLITE_LOG(WARN) << "Having a manually-set TfLite delegate, and bypassing " "KernelTestDelegateProviders"; - TF_LITE_ENSURE_STATUS(interpreter_->ModifyGraphWithDelegate(delegate_)); + SetDelegateApplicationStatus( + interpreter_->ModifyGraphWithDelegate(delegate_)); + TF_LITE_ENSURE_STATUS(*GetDelegateApplicationStatus()); ++num_applied_delegates_; } else { auto* delegate_providers = tflite::KernelTestDelegateProviders::Get(); @@ -292,8 +294,9 @@ TfLiteStatus SingleOpModel::ApplyDelegate() { for (auto& one : delegate_providers->CreateAllDelegates()) { // The raw ptr always points to the actual TfLiteDegate object. auto* delegate_raw_ptr = one.delegate.get(); - TF_LITE_ENSURE_STATUS( + SetDelegateApplicationStatus( interpreter_->ModifyGraphWithDelegate(std::move(one.delegate))); + TF_LITE_ENSURE_STATUS(*GetDelegateApplicationStatus()); // Note: 'delegate_' is always set to the last successfully applied one. delegate_ = delegate_raw_ptr; ++num_applied_delegates_; diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index c4c18fb3eef57f..710ab60d0e28e0 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -29,6 +29,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -747,6 +748,13 @@ class SingleOpModel { int CountOpsExecutedByCpuKernel(); int CountNumberOfDelegatedPartitions() const; int GetNumberOfAppliedDelegates() const { return num_applied_delegates_; } + // Return the most recent return status of ApplyDelegate. + std::optional GetDelegateApplicationStatus() const { + return delegate_application_status_; + } + void SetDelegateApplicationStatus(std::optional status) { + delegate_application_status_ = status; + } // Tell TF Lite runtime to apply default delegates (i.e. XNNPACK delegate) // when handling this op-level model. @@ -1082,6 +1090,7 @@ class SingleOpModel { std::vector> tensors_; std::vector> buffers_; TfLiteDelegate* delegate_ = nullptr; // not own the memory. 
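The member introduced just below caches the outcome of the most recent ModifyGraphWithDelegate call, with std::nullopt meaning no delegate application was ever attempted. A minimal sketch of the pattern, independent of the TF Lite types; in SingleOpModel the same idea lets ApplyDelegate expose its failure code to tests without changing its TfLiteStatus return value.

#include <optional>

class DelegateStatusCache {
 public:
  // Record the status of the latest delegate-application attempt.
  void Set(int status) { status_ = status; }
  // Remains nullopt until Set() has been called at least once.
  std::optional<int> Get() const { return status_; }

 private:
  std::optional<int> status_;
};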
+ std::optional delegate_application_status_ = std::nullopt; std::vector> input_shapes_; int num_applied_delegates_ = 0; bool allow_fp32_relax_to_fp16_ = false; diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD index 54920ef71dc625..03b5438a973ee9 100644 --- a/tensorflow/lite/profiling/BUILD +++ b/tensorflow/lite/profiling/BUILD @@ -194,6 +194,8 @@ cc_library( copts = common_copts, deps = [ "//tensorflow/core/util:stats_calculator_portable", + "//tensorflow/lite/profiling/proto:profiling_info_cc_proto", + "//tensorflow/lite/tools:logging", ], ) @@ -202,6 +204,7 @@ cc_test( srcs = ["profile_summary_formatter_test.cc"], deps = [ ":profile_summary_formatter", + "//tensorflow/lite/profiling/proto:profiling_info_cc_proto", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/profiling/profile_summarizer.cc b/tensorflow/lite/profiling/profile_summarizer.cc index 4bbf4e403a2f2d..f8e461f2baebea 100644 --- a/tensorflow/lite/profiling/profile_summarizer.cc +++ b/tensorflow/lite/profiling/profile_summarizer.cc @@ -201,6 +201,8 @@ void ProfileSummarizer::ProcessProfiles( if (delegate_internal_total_us > 0) { delegate_stats_calculator_->UpdateRunTotalUs(delegate_internal_total_us); } + + SetSubgraphNameMap(interpreter); } tensorflow::StatsCalculator* ProfileSummarizer::GetStatsCalculator( diff --git a/tensorflow/lite/profiling/profile_summarizer.h b/tensorflow/lite/profiling/profile_summarizer.h index 3007440d680159..986bb691c18aee 100644 --- a/tensorflow/lite/profiling/profile_summarizer.h +++ b/tensorflow/lite/profiling/profile_summarizer.h @@ -45,13 +45,13 @@ class ProfileSummarizer { // Returns a string detailing the accumulated runtime stats in the format of // summary_formatter_. std::string GetOutputString() { - return summary_formatter_->GetOutputString(stats_calculator_map_, - *delegate_stats_calculator_); + return summary_formatter_->GetOutputString( + stats_calculator_map_, *delegate_stats_calculator_, subgraph_name_map_); } std::string GetShortSummary() { - return summary_formatter_->GetShortSummary(stats_calculator_map_, - *delegate_stats_calculator_); + return summary_formatter_->GetShortSummary( + stats_calculator_map_, *delegate_stats_calculator_, subgraph_name_map_); } tensorflow::StatsCalculator* GetStatsCalculator(uint32_t subgraph_index); @@ -73,6 +73,17 @@ class ProfileSummarizer { // Summary formatter for customized output formats. std::shared_ptr summary_formatter_; + + std::map subgraph_name_map_; + + void SetSubgraphNameMap(const tflite::Interpreter& interpreter) { + subgraph_name_map_.clear(); + for (int subgraph_index = 0; subgraph_index < interpreter.subgraphs_size(); + ++subgraph_index) { + subgraph_name_map_[subgraph_index] = + interpreter.subgraph(subgraph_index)->GetName(); + } + } }; } // namespace profiling diff --git a/tensorflow/lite/profiling/profile_summary_formatter.cc b/tensorflow/lite/profiling/profile_summary_formatter.cc index 5c7bea2c279e11..31f235c999351a 100644 --- a/tensorflow/lite/profiling/profile_summary_formatter.cc +++ b/tensorflow/lite/profiling/profile_summary_formatter.cc @@ -15,10 +15,20 @@ limitations under the License. 
#include "tensorflow/lite/profiling/profile_summary_formatter.h" +#include +#include +#include #include #include +#include +#include #include #include +#include +#include + +#include "tensorflow/lite/profiling/proto/profiling_info.pb.h" +#include "tensorflow/lite/tools/logging.h" namespace tflite { namespace profiling { @@ -26,35 +36,47 @@ namespace profiling { std::string ProfileSummaryDefaultFormatter::GetOutputString( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const { + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const { return GenerateReport("profile", /*include_output_string*/ true, - stats_calculator_map, delegate_stats_calculator); + stats_calculator_map, delegate_stats_calculator, + subgraph_name_map); } std::string ProfileSummaryDefaultFormatter::GetShortSummary( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const { + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const { return GenerateReport("summary", /*include_output_string*/ false, - stats_calculator_map, delegate_stats_calculator); + stats_calculator_map, delegate_stats_calculator, + subgraph_name_map); } std::string ProfileSummaryDefaultFormatter::GenerateReport( const std::string& tag, bool include_output_string, const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const { + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const { std::stringstream stream; bool has_non_primary_graph = (stats_calculator_map.size() - stats_calculator_map.count(0)) > 0; for (const auto& stats_calc : stats_calculator_map) { auto subgraph_index = stats_calc.first; auto subgraph_stats = stats_calc.second.get(); + std::string subgraph_name = ""; + if (subgraph_name_map.find(subgraph_index) != subgraph_name_map.end()) { + subgraph_name = subgraph_name_map.at(subgraph_index); + } + if (has_non_primary_graph) { if (subgraph_index == 0) { - stream << "Primary graph " << tag << ":" << std::endl; + stream << "Primary graph (name: " << subgraph_name << ") " << tag << ":" + << std::endl; } else { - stream << "Subgraph (index: " << subgraph_index << ") " << tag << ":" + stream << "Subgraph (index: " << subgraph_index + << ", name: " << subgraph_name << ") " << tag << ":" << std::endl; } } @@ -62,7 +84,8 @@ std::string ProfileSummaryDefaultFormatter::GenerateReport( stream << subgraph_stats->GetOutputString(); } if (subgraph_index != 0) { - stream << "Subgraph (index: " << subgraph_index << ") "; + stream << "Subgraph (index: " << subgraph_index + << ", name: " << subgraph_name << ") "; } stream << subgraph_stats->GetShortSummary() << std::endl; } @@ -78,6 +101,25 @@ std::string ProfileSummaryDefaultFormatter::GenerateReport( return stream.str(); } +void ProfileSummaryDefaultFormatter::HandleOutput( + const std::string& init_output, const std::string& run_output, + std::string output_file_path) const { + std::ofstream output_file(output_file_path); + std::ostream* output_stream = nullptr; + if (output_file.good()) { + output_stream = &output_file; + } + if (!init_output.empty()) { + WriteOutput("Profiling Info for Benchmark Initialization:", init_output, + output_stream == nullptr ? 
&TFLITE_LOG(INFO) : output_stream); + } + if (!run_output.empty()) { + WriteOutput( + "Operator-wise Profiling Info for Regular Benchmark Runs:", run_output, + output_stream == nullptr ? &TFLITE_LOG(INFO) : output_stream); + } +} + tensorflow::StatSummarizerOptions ProfileSummaryDefaultFormatter::GetStatSummarizerOptions() const { auto options = tensorflow::StatSummarizerOptions(); @@ -95,5 +137,172 @@ ProfileSummaryCSVFormatter::GetStatSummarizerOptions() const { return options; } +std::vector +ProfileSummaryProtoFormatter::GetDetailsSortedByRunOrder( + const tensorflow::StatsCalculator* stats_calculator) const { + std::vector details; + std::map unsorted_details = + stats_calculator->GetDetails(); + + std::priority_queue< + std::pair> + sorted_list; + const int num_nodes = unsorted_details.size(); + for (const auto& det : unsorted_details) { + const tensorflow::StatsCalculator::Detail* detail = &(det.second); + std::stringstream stream_for_sort; + stream_for_sort << std::setw(20) << std::right << std::setprecision(10) + << std::fixed; + stream_for_sort << num_nodes - detail->run_order; + sorted_list.emplace(stream_for_sort.str(), detail); + } + + while (!sorted_list.empty()) { + auto entry = sorted_list.top(); + sorted_list.pop(); + details.push_back(*entry.second); + } + return details; +} + +void ProfileSummaryProtoFormatter::GenerateOpProfileDataFromDetail( + const tensorflow::StatsCalculator::Detail* detail, + const tensorflow::StatsCalculator* stats_calculator, + OpProfileData* const op_profile_data) const { + if (detail == nullptr) { + return; + } + + op_profile_data->set_node_type(detail->type); + OpProfilingStat* inference_stat = + op_profile_data->mutable_inference_microseconds(); + inference_stat->set_first(detail->elapsed_time.first()); + inference_stat->set_last(detail->elapsed_time.newest()); + inference_stat->set_avg(detail->elapsed_time.avg()); + inference_stat->set_stddev(detail->elapsed_time.std_deviation()); + inference_stat->set_variance(detail->elapsed_time.variance()); + inference_stat->set_min(detail->elapsed_time.min()); + inference_stat->set_max(detail->elapsed_time.max()); + inference_stat->set_sum(detail->elapsed_time.sum()); + inference_stat->set_count(detail->elapsed_time.count()); + + OpProfilingStat* memory_stat = op_profile_data->mutable_mem_kb(); + memory_stat->set_first(detail->mem_used.first() / 1000.0); + memory_stat->set_last(detail->mem_used.newest() / 1000.0); + memory_stat->set_avg(detail->mem_used.avg() / 1000.0); + memory_stat->set_stddev(detail->mem_used.std_deviation() / 1000.0); + memory_stat->set_variance(detail->mem_used.variance() / 1000000.0); + memory_stat->set_min(detail->mem_used.min() / 1000.0); + memory_stat->set_max(detail->mem_used.max() / 1000.0); + memory_stat->set_sum(detail->mem_used.sum() / 1000.0); + memory_stat->set_count(detail->mem_used.count()); + + op_profile_data->set_times_called(detail->times_called / + stats_calculator->num_runs()); + op_profile_data->set_name(detail->name); + op_profile_data->set_run_order(detail->run_order); +} + +void ProfileSummaryProtoFormatter::GenerateSubGraphProfilingData( + const tensorflow::StatsCalculator* stats_calculator, int subgraph_index, + const std::map& subgraph_name_map, + SubGraphProfilingData* const sub_graph_profiling_data) const { + sub_graph_profiling_data->set_subgraph_index(subgraph_index); + + std::string subgraph_name = ""; + if (subgraph_name_map.find(subgraph_index) != subgraph_name_map.end()) { + subgraph_name = subgraph_name_map.at(subgraph_index); + } + 
sub_graph_profiling_data->set_subgraph_name(subgraph_name); + + for (tensorflow::StatsCalculator::Detail& detail : + GetDetailsSortedByRunOrder(stats_calculator)) { + OpProfileData* const op_profile_data = + sub_graph_profiling_data->add_per_op_profiles(); + GenerateOpProfileDataFromDetail(&detail, stats_calculator, op_profile_data); + } +} + +void ProfileSummaryProtoFormatter::GenerateDelegateProfilingData( + const tensorflow::StatsCalculator* stats_calculator, + DelegateProfilingData* const delegate_profiling_data) const { + for (const tensorflow::StatsCalculator::Detail& detail : + GetDetailsSortedByRunOrder(stats_calculator)) { + OpProfileData* const op_profile_data = + delegate_profiling_data->add_per_op_profiles(); + GenerateOpProfileDataFromDetail(&detail, stats_calculator, op_profile_data); + } +} + +std::string ProfileSummaryProtoFormatter::GetShortSummary( + const std::map>& + stats_calculator_map, + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const { + TFLITE_LOG(ERROR) << "GetShortSummary is not supported for proto formatter."; + return ""; +} + +std::string ProfileSummaryProtoFormatter::GetOutputString( + const std::map>& + stats_calculator_map, + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const { + ModelProfilingData model_profiling_data; + for (const auto& stats_calc : stats_calculator_map) { + auto subgraph_index = stats_calc.first; + tensorflow::StatsCalculator* subgraph_stats = stats_calc.second.get(); + SubGraphProfilingData* const sub_graph_profiling_data = + model_profiling_data.add_subgraph_profiles(); + GenerateSubGraphProfilingData(subgraph_stats, subgraph_index, + subgraph_name_map, sub_graph_profiling_data); + } + + if (delegate_stats_calculator.num_runs() > 0) { + DelegateProfilingData* const delegate_profiling_data = + model_profiling_data.add_delegate_profiles(); + GenerateDelegateProfilingData(&delegate_stats_calculator, + delegate_profiling_data); + } + + return model_profiling_data.SerializeAsString(); +} + +tensorflow::StatSummarizerOptions +ProfileSummaryProtoFormatter::GetStatSummarizerOptions() const { + auto options = tensorflow::StatSummarizerOptions(); + // Summary will be manually handled per subgraph in order to keep + // compatibility.
+ options.show_summary = false; + options.show_memory = false; + return options; +} + +void ProfileSummaryProtoFormatter::HandleOutput( + const std::string& init_output, const std::string& run_output, + std::string output_file_path) const { + std::ofstream output_file(output_file_path, std::ios_base::binary); + std::ostream* output_stream = nullptr; + if (output_file.good()) { + output_stream = &output_file; + } + + BenchmarkProfilingData benchmark_profiling_data; + if (!init_output.empty()) { + benchmark_profiling_data.mutable_init_profile()->ParseFromString( + init_output); + } + if (!run_output.empty()) { + benchmark_profiling_data.mutable_runtime_profile()->ParseFromString( + run_output); + } + + if (output_stream == nullptr) { + TFLITE_LOG(INFO) << benchmark_profiling_data.DebugString(); + } else { + benchmark_profiling_data.SerializeToOstream(output_stream); + } +} + } // namespace profiling } // namespace tflite diff --git a/tensorflow/lite/profiling/profile_summary_formatter.h b/tensorflow/lite/profiling/profile_summary_formatter.h index 9c7a13530bbd93..62514eafecda93 100644 --- a/tensorflow/lite/profiling/profile_summary_formatter.h +++ b/tensorflow/lite/profiling/profile_summary_formatter.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARY_FORMATTER_H_ #define TENSORFLOW_LITE_PROFILING_PROFILE_SUMMARY_FORMATTER_H_ +#include +#include +#include #include #include #include @@ -23,7 +26,9 @@ limitations under the License. #include #include +#include "tensorflow/core/util/stat_summarizer_options.h" #include "tensorflow/core/util/stats_calculator.h" +#include "tensorflow/lite/profiling/proto/profiling_info.pb.h" namespace tflite { namespace profiling { @@ -31,54 +36,110 @@ namespace profiling { // Formats the profile summary in a certain way. class ProfileSummaryFormatter { public: - ProfileSummaryFormatter() {} + ProfileSummaryFormatter() = default; virtual ~ProfileSummaryFormatter() {} // Returns a string detailing the accumulated runtime stats in StatsCalculator // of ProfileSummarizer. virtual std::string GetOutputString( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const = 0; + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const = 0; // Returns a string detailing the short summary of the accumulated runtime // stats in StatsCalculator of ProfileSummarizer. 
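Every formatter entry point now takes the subgraph-index-to-name map alongside the stats calculators, which is what lets reports print names next to indices. A usage sketch of the extended interface; the uint32_t key type is an assumption based on GetStatsCalculator(uint32_t) above, and the include paths are the ones these files already use.

#include <cstdint>
#include <map>
#include <memory>
#include <string>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/profiling/profile_summary_formatter.h"

// Render a report with human-readable subgraph names attached.
std::string RenderReport(
    const tflite::profiling::ProfileSummaryFormatter& formatter,
    const std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>&
        stats_calculator_map,
    const tensorflow::StatsCalculator& delegate_stats) {
  const std::map<uint32_t, std::string> names = {{0, "Primary graph"},
                                                 {1, "Subgraph 1"}};
  return formatter.GetOutputString(stats_calculator_map, delegate_stats,
                                   names);
}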
virtual std::string GetShortSummary( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const = 0; + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const = 0; virtual tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const = 0; + virtual void HandleOutput(const std::string& init_output, + const std::string& run_output, + std::string output_file_path) const = 0; }; class ProfileSummaryDefaultFormatter : public ProfileSummaryFormatter { public: - ProfileSummaryDefaultFormatter() {} + ProfileSummaryDefaultFormatter() = default; ~ProfileSummaryDefaultFormatter() override {} std::string GetOutputString( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) - const override; + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const override; std::string GetShortSummary( const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) - const override; + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const override; tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const override; + void HandleOutput(const std::string& init_output, + const std::string& run_output, + std::string output_file_path) const override; private: std::string GenerateReport( const std::string& tag, bool include_output_string, const std::map>& stats_calculator_map, - const tensorflow::StatsCalculator& delegate_stats_calculator) const; + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const; + void WriteOutput(const std::string& header, const std::string& data, + std::ostream* stream) const { + (*stream) << header << std::endl; + (*stream) << data << std::endl; + } }; class ProfileSummaryCSVFormatter : public ProfileSummaryDefaultFormatter { public: - ProfileSummaryCSVFormatter() {} + ProfileSummaryCSVFormatter() = default; tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const override; }; +class ProfileSummaryProtoFormatter : public ProfileSummaryFormatter { + public: + std::string GetOutputString( + const std::map>& + stats_calculator_map, + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const override; + std::string GetShortSummary( + const std::map>& + stats_calculator_map, + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const override; + tensorflow::StatSummarizerOptions GetStatSummarizerOptions() const override; + void HandleOutput(const std::string& init_output, + const std::string& run_output, + std::string output_file_path) const override; + + private: + std::string GenerateReport( + const std::string& tag, bool include_output_string, + const std::map>& + stats_calculator_map, + const tensorflow::StatsCalculator& delegate_stats_calculator, + const std::map& subgraph_name_map) const; + void GenerateSubGraphProfilingData( + const tensorflow::StatsCalculator* stats_calculator, int subgraph_index, + const std::map& subgraph_name_map, + SubGraphProfilingData* sub_graph_profiling_data) const; + + void GenerateDelegateProfilingData( + const tensorflow::StatsCalculator* stats_calculator, + DelegateProfilingData* delegate_profiling_data) const; + + void GenerateOpProfileDataFromDetail( + const tensorflow::StatsCalculator::Detail* detail, + const 
tensorflow::StatsCalculator* stats_calculator, + OpProfileData* op_profile_data) const; + + std::vector GetDetailsSortedByRunOrder( + const tensorflow::StatsCalculator* stats_calculator) const; +}; + } // namespace profiling } // namespace tflite diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc index eefd35667e3b2a..d9f26e0b729bc7 100644 --- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc +++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/profiling/profile_summary_formatter.h" +#include +#include #include #include #include @@ -21,6 +23,7 @@ limitations under the License. #include #include #include "absl/strings/match.h" +#include "tensorflow/lite/profiling/proto/profiling_info.pb.h" namespace tflite { namespace profiling { @@ -46,7 +49,7 @@ TEST(SummaryWriterTest, EmptyOutputString) { ProfileSummaryDefaultFormatter writer; std::string output = writer.GetOutputString( std::map>(), - tensorflow::StatsCalculator(writer.GetStatSummarizerOptions())); + tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()), {}); EXPECT_EQ(output.size(), 0); } @@ -54,7 +57,7 @@ TEST(SummaryWriterTest, EmptyShortSummary) { ProfileSummaryDefaultFormatter writer; std::string output = writer.GetShortSummary( std::map>(), - tensorflow::StatsCalculator(writer.GetStatSummarizerOptions())); + tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()), {}); EXPECT_EQ(output.size(), 0); } @@ -66,7 +69,7 @@ TEST(SummaryWriterTest, SingleSubgraphOutputString) { writer.GetStatSummarizerOptions()); std::string output = writer.GetOutputString( stats_calculator_map, - tensorflow::StatsCalculator(writer.GetStatSummarizerOptions())); + tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()), {}); ASSERT_TRUE(absl::StrContains(output, "Run Order")); ASSERT_TRUE(absl::StrContains(output, "Top by Computation Time")); ASSERT_TRUE(!absl::StrContains(output, "Top by Memory Use")); @@ -85,7 +88,8 @@ TEST(SummaryWriterTest, SingleSubgraphShortSummary) { writer.GetStatSummarizerOptions()); std::string output = writer.GetShortSummary( stats_calculator_map, - tensorflow::StatsCalculator(writer.GetStatSummarizerOptions())); + tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()), + {{0, "Primary graph"}}); ASSERT_TRUE(!absl::StrContains(output, "Run Order")); ASSERT_TRUE(!absl::StrContains(output, "Top by Computation Time")); ASSERT_TRUE(!absl::StrContains(output, "Top by Memory Use")); @@ -106,12 +110,251 @@ TEST(SummaryWriterTest, MultiSubgraphOutputString) { writer.GetStatSummarizerOptions()); std::string output = writer.GetOutputString( stats_calculator_map, - tensorflow::StatsCalculator(writer.GetStatSummarizerOptions())); + tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()), + {{0, "Primary graph"}, {1, "Subgraph 1"}}); ASSERT_TRUE(absl::StrContains(output, "Primary graph")); ASSERT_TRUE(absl::StrContains(output, "Subgraph")); ASSERT_TRUE(!absl::StrContains(output, "Delegate internal")); } +TEST(SummaryWriterTest, MultiSubgraphOutputStringForProto) { + ProfileSummaryProtoFormatter writer; + std::map> + stats_calculator_map; + stats_calculator_map[0] = std::make_unique( + writer.GetStatSummarizerOptions()); + std::string kernel_name_1 = "Kernel 1"; + std::string kernel_name_2 = "Kernel 2"; + std::string kernel_name_3 = "Kernel 3"; 
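The remainder of this test leans on a plain protobuf round trip: GetOutputString returns ModelProfilingData::SerializeAsString(), and the test parses those bytes back before comparing messages with EqualsProto. A condensed sketch of that round trip; the tflite::profiling namespace is assumed from the generated profiling_info.pb.h header included above.

#include <string>

#include "tensorflow/lite/profiling/proto/profiling_info.pb.h"

// Serialize-then-parse round trip, as exercised below: the formatter emits
// wire-format bytes, and the test reconstructs the message from them.
std::string RoundTripSubgraphName() {
  tflite::profiling::ModelProfilingData data;
  data.add_subgraph_profiles()->set_subgraph_name("Primary graph");
  const std::string wire = data.SerializeAsString();

  tflite::profiling::ModelProfilingData parsed;
  parsed.ParseFromString(wire);
  return parsed.subgraph_profiles(0).subgraph_name();  // "Primary graph"
}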
+
+ std::string op_name_1 = "Convolution";
+ std::string op_name_2 = "Reshape";
+ std::string op_name_3 = "Convolution";
+ stats_calculator_map[0]->AddNodeStats(kernel_name_1, op_name_1, 1, 10, 10000);
+ stats_calculator_map[0]->AddNodeStats(kernel_name_1, op_name_1, 1, 20, 20000);
+ stats_calculator_map[0]->AddNodeStats(kernel_name_2, op_name_2, 2, 15, 10000);
+ stats_calculator_map[0]->UpdateRunTotalUs(25);
+ stats_calculator_map[1] = std::make_unique<tensorflow::StatsCalculator>(
+ writer.GetStatSummarizerOptions());
+ stats_calculator_map[1]->AddNodeStats(kernel_name_3, op_name_3, 3, 10, 10000);
+ stats_calculator_map[1]->UpdateRunTotalUs(10);
+
+ std::string output = writer.GetOutputString(
+ stats_calculator_map,
+ tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()),
+ {{0, "Primary graph"}, {1, "Subgraph 1"}});
+ ModelProfilingData model_profiling_data;
+ model_profiling_data.ParseFromString(output);
+ ASSERT_TRUE(absl::StrContains(output, "Primary graph"));
+ ASSERT_TRUE(absl::StrContains(output, "Subgraph"));
+ ASSERT_TRUE(!absl::StrContains(output, "Delegate internal"));
+ ASSERT_EQ(model_profiling_data.subgraph_profiles().size(), 2);
+ ASSERT_EQ(model_profiling_data.subgraph_profiles(0).subgraph_name(),
+ "Primary graph");
+ ASSERT_EQ(model_profiling_data.subgraph_profiles(0).per_op_profiles().size(),
+ 2);
+
+ OpProfileData op_profile_data_1;
+ op_profile_data_1.set_node_type(op_name_1);
+ OpProfilingStat* inference_microseconds_stat_1 =
+ op_profile_data_1.mutable_inference_microseconds();
+ inference_microseconds_stat_1->set_first(10);
+ inference_microseconds_stat_1->set_last(20);
+ inference_microseconds_stat_1->set_max(20);
+ inference_microseconds_stat_1->set_min(10);
+ inference_microseconds_stat_1->set_avg(15);
+ inference_microseconds_stat_1->set_stddev(5);
+ inference_microseconds_stat_1->set_variance(25);
+ inference_microseconds_stat_1->set_sum(30);
+ inference_microseconds_stat_1->set_count(2);
+ OpProfilingStat* memory_stat_1 = op_profile_data_1.mutable_mem_kb();
+ memory_stat_1->set_first(10);
+ memory_stat_1->set_last(20);
+ memory_stat_1->set_max(20);
+ memory_stat_1->set_min(10);
+ memory_stat_1->set_avg(15);
+ memory_stat_1->set_stddev(5);
+ memory_stat_1->set_variance(25);
+ memory_stat_1->set_sum(30);
+ memory_stat_1->set_count(2);
+ op_profile_data_1.set_name(kernel_name_1);
+ op_profile_data_1.set_run_order(1);
+ op_profile_data_1.set_times_called(2);
+ EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(0),
+ testing::EqualsProto(op_profile_data_1));
+
+ OpProfileData op_profile_data_2;
+ op_profile_data_2.set_node_type(op_name_2);
+ OpProfilingStat* inference_microseconds_stat_2 =
+ op_profile_data_2.mutable_inference_microseconds();
+ inference_microseconds_stat_2->set_first(15);
+ inference_microseconds_stat_2->set_last(15);
+ inference_microseconds_stat_2->set_max(15);
+ inference_microseconds_stat_2->set_min(15);
+ inference_microseconds_stat_2->set_avg(15);
+ inference_microseconds_stat_2->set_stddev(0);
+ inference_microseconds_stat_2->set_variance(0);
+ inference_microseconds_stat_2->set_sum(15);
+ inference_microseconds_stat_2->set_count(1);
+ OpProfilingStat* memory_stat_2 = op_profile_data_2.mutable_mem_kb();
+ memory_stat_2->set_first(10);
+ memory_stat_2->set_last(10);
+ memory_stat_2->set_max(10);
+ memory_stat_2->set_min(10);
+ memory_stat_2->set_avg(10);
+ memory_stat_2->set_stddev(0);
+ memory_stat_2->set_variance(0);
+ memory_stat_2->set_sum(10);
+ memory_stat_2->set_count(1);
+ op_profile_data_2.set_times_called(1);
+
op_profile_data_2.set_name(kernel_name_2); + op_profile_data_2.set_run_order(2); + + EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(1), + testing::EqualsProto(op_profile_data_2)); + + ASSERT_EQ(model_profiling_data.subgraph_profiles(1).subgraph_name(), + "Subgraph 1"); + ASSERT_EQ(model_profiling_data.subgraph_profiles(1).per_op_profiles().size(), + 1); + + OpProfileData op_profile_data_3; + op_profile_data_3.set_node_type(op_name_3); + OpProfilingStat* inference_microseconds_stat_3 = + op_profile_data_3.mutable_inference_microseconds(); + inference_microseconds_stat_3->set_first(10); + inference_microseconds_stat_3->set_last(10); + inference_microseconds_stat_3->set_max(10); + inference_microseconds_stat_3->set_min(10); + inference_microseconds_stat_3->set_avg(10); + inference_microseconds_stat_3->set_stddev(0); + inference_microseconds_stat_3->set_variance(0); + inference_microseconds_stat_3->set_sum(10); + inference_microseconds_stat_3->set_count(1); + OpProfilingStat* memory_stat_3 = op_profile_data_3.mutable_mem_kb(); + memory_stat_3->set_first(10); + memory_stat_3->set_last(10); + memory_stat_3->set_max(10); + memory_stat_3->set_min(10); + memory_stat_3->set_avg(10); + memory_stat_3->set_stddev(0); + memory_stat_3->set_variance(0); + memory_stat_3->set_sum(10); + memory_stat_3->set_count(1); + op_profile_data_3.set_times_called(1); + op_profile_data_3.set_name(kernel_name_3); + op_profile_data_3.set_run_order(3); + EXPECT_THAT(model_profiling_data.subgraph_profiles(1).per_op_profiles(0), + testing::EqualsProto(op_profile_data_3)); +} + +TEST(SummaryWriterTest, MultiSubgraphHandleOutputForProto) { + ProfileSummaryProtoFormatter writer; + + ModelProfilingData model_profiling_data_run; + SubGraphProfilingData* subgraph_profiling_data = + model_profiling_data_run.add_subgraph_profiles(); + subgraph_profiling_data->set_subgraph_name("Primary graph"); + OpProfileData* op_profile_data_1 = + subgraph_profiling_data->add_per_op_profiles(); + op_profile_data_1->set_node_type("Convolution"); + OpProfilingStat* inference_stat_1 = + op_profile_data_1->mutable_inference_microseconds(); + inference_stat_1->set_first(10); + inference_stat_1->set_avg(10); + OpProfilingStat* mem_stat_1 = op_profile_data_1->mutable_mem_kb(); + mem_stat_1->set_first(10); + mem_stat_1->set_avg(10); + op_profile_data_1->set_times_called(1); + op_profile_data_1->set_name("Kernel 1"); + op_profile_data_1->set_run_order(1); + OpProfileData* op_profile_data_2 = + subgraph_profiling_data->add_per_op_profiles(); + op_profile_data_2->set_node_type("Reshape"); + OpProfilingStat* inference_stat_2 = + op_profile_data_2->mutable_inference_microseconds(); + inference_stat_2->set_first(15); + inference_stat_2->set_avg(15); + OpProfilingStat* mem_stat_2 = op_profile_data_2->mutable_mem_kb(); + mem_stat_2->set_first(10); + mem_stat_2->set_avg(10); + op_profile_data_2->set_times_called(1); + op_profile_data_2->set_name("Kernel 2"); + op_profile_data_2->set_run_order(2); + SubGraphProfilingData* subgraph_profiling_data_1 = + model_profiling_data_run.add_subgraph_profiles(); + subgraph_profiling_data_1->set_subgraph_name("Subgraph 1"); + OpProfileData* op_profile_data_3 = + subgraph_profiling_data_1->add_per_op_profiles(); + op_profile_data_3->set_node_type("Convolution"); + OpProfilingStat* inference_stat_3 = + op_profile_data_3->mutable_inference_microseconds(); + inference_stat_3->set_first(10); + inference_stat_3->set_avg(10); + OpProfilingStat* mem_stat_3 = op_profile_data_3->mutable_mem_kb(); + 
mem_stat_3->set_first(10);
+ mem_stat_3->set_avg(10);
+ op_profile_data_3->set_times_called(1);
+ op_profile_data_3->set_name("Kernel 3");
+ op_profile_data_3->set_run_order(3);
+ DelegateProfilingData* delegate_profiling_data =
+ model_profiling_data_run.add_delegate_profiles();
+ OpProfileData* op_profile_data_4 =
+ delegate_profiling_data->add_per_op_profiles();
+ op_profile_data_4->set_node_type("Convolution");
+ OpProfilingStat* inference_stat_4 =
+ op_profile_data_4->mutable_inference_microseconds();
+ inference_stat_4->set_first(10);
+ inference_stat_4->set_avg(10);
+ OpProfilingStat* mem_stat_4 = op_profile_data_4->mutable_mem_kb();
+ mem_stat_4->set_first(10);
+ mem_stat_4->set_avg(10);
+ op_profile_data_4->set_times_called(1);
+ op_profile_data_4->set_name("Kernel 4");
+ op_profile_data_4->set_run_order(4);
+
+ ModelProfilingData model_profiling_data_init;
+ SubGraphProfilingData* subgraph_profiling_data_init =
+ model_profiling_data_init.add_subgraph_profiles();
+ subgraph_profiling_data_init->set_subgraph_name("Primary graph");
+ OpProfileData* op_profile_data_init_1 =
+ subgraph_profiling_data_init->add_per_op_profiles();
+ op_profile_data_init_1->set_node_type("Convolution");
+ OpProfilingStat* inference_stat_init_1 =
+ op_profile_data_init_1->mutable_inference_microseconds();
+ inference_stat_init_1->set_first(10);
+ inference_stat_init_1->set_avg(10);
+ op_profile_data_init_1->set_times_called(1);
+ OpProfilingStat* mem_stat_init_1 = op_profile_data_init_1->mutable_mem_kb();
+ mem_stat_init_1->set_first(10);
+ mem_stat_init_1->set_avg(10);
+ op_profile_data_init_1->set_name("ModifyGraphWithDelegate");
+ op_profile_data_init_1->set_run_order(1);
+
+#ifdef __ANDROID__
+ std::string file_name = "/data/local/tmp/test_file.proto";
+#else
+ std::string file_name = "/tmp/test_file.proto";
+#endif
+
+ writer.HandleOutput(model_profiling_data_init.SerializeAsString(),
+ model_profiling_data_run.SerializeAsString(), file_name);
+
+ std::ifstream file(file_name, std::ios::binary);
+
+ ASSERT_TRUE(file.good());
+
+ BenchmarkProfilingData benchmark_profiling_data;
+ benchmark_profiling_data.ParseFromIstream(&file);
+ file.close();
+
+ ASSERT_TRUE(benchmark_profiling_data.model_name().empty());
+ EXPECT_THAT(benchmark_profiling_data.init_profile(),
+ testing::EqualsProto(model_profiling_data_init));
+ EXPECT_THAT(benchmark_profiling_data.runtime_profile(),
+ testing::EqualsProto(model_profiling_data_run));
+}
+
TEST(SummaryWriterTest, MultiSubgraphShortSummary) {
ProfileSummaryDefaultFormatter writer;
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>
@@ -122,7 +365,8 @@ TEST(SummaryWriterTest, MultiSubgraphShortSummary) {
writer.GetStatSummarizerOptions());
std::string output = writer.GetShortSummary(
stats_calculator_map,
- tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()));
+ tensorflow::StatsCalculator(writer.GetStatSummarizerOptions()),
+ {{0, "Primary graph"}, {1, "Subgraph 1"}});
ASSERT_TRUE(absl::StrContains(output, "Primary graph"));
ASSERT_TRUE(absl::StrContains(output, "Subgraph"));
ASSERT_TRUE(!absl::StrContains(output, "Delegate internal"));
@@ -135,7 +379,7 @@ TEST(SummaryWriterTest, DelegationOutputString) {
delegate_stats_calculator.UpdateRunTotalUs(1);
std::string output = writer.GetOutputString(
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
- delegate_stats_calculator);
+ delegate_stats_calculator, {});
ASSERT_TRUE(!absl::StrContains(output, "Primary graph"));
ASSERT_TRUE(!absl::StrContains(output, "Subgraph"));
ASSERT_TRUE(absl::StrContains(output, "Delegate internal"));
@@ -148,7 +392,7 @@ TEST(SummaryWriterTest,
DelegationShortSummary) {
delegate_stats_calculator.UpdateRunTotalUs(1);
std::string output = writer.GetShortSummary(
std::map<uint32_t, std::unique_ptr<tensorflow::StatsCalculator>>(),
- delegate_stats_calculator);
+ delegate_stats_calculator, {});
ASSERT_TRUE(!absl::StrContains(output, "Primary graph"));
ASSERT_TRUE(!absl::StrContains(output, "Subgraph"));
ASSERT_TRUE(absl::StrContains(output, "Delegate internal"));
diff --git a/tensorflow/lite/profiling/proto/BUILD b/tensorflow/lite/profiling/proto/BUILD
new file mode 100644
index 00000000000000..5e3160b318bf8e
--- /dev/null
+++ b/tensorflow/lite/profiling/proto/BUILD
@@ -0,0 +1,41 @@
+# Placeholder: load py_proto_library
+load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable")
+load(
+ "//tensorflow/core/platform:build_config.bzl",
+ "tf_proto_library",
+)
+# copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+proto_library(
+ name = "profiling_info_proto",
+ srcs = ["profiling_info.proto"],
+ compatible_with = get_compatible_with_portable(),
+ visibility = ["//visibility:public"],
+)
+
+cc_proto_library(
+ name = "profiling_info_cc_proto",
+ compatible_with = get_compatible_with_portable(),
+ deps = [":profiling_info_proto"],
+)
+
+tf_proto_library(
+ name = "profiling_info", # bzl adds _py
+ srcs = ["profiling_info.proto"],
+ visibility = ["//visibility:public"],
+)
+
+# copybara:uncomment_begin(google-only)
+# py_proto_library(
+# name = "profiling_info_py_pb2",
+# api_version = 2,
+# compatible_with = get_compatible_with_portable(),
+# deps = [":profiling_info_proto"],
+# )
+# copybara:uncomment_end
diff --git a/tensorflow/lite/profiling/proto/CMakeLists.txt b/tensorflow/lite/profiling/proto/CMakeLists.txt
new file mode 100644
index 00000000000000..a0955470db7d6f
--- /dev/null
+++ b/tensorflow/lite/profiling/proto/CMakeLists.txt
@@ -0,0 +1,41 @@
+#
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+find_package(Protobuf REQUIRED)
+
+add_library(profiling_info_proto profiling_info.proto)
+
+list(APPEND proto_generated_files ${CMAKE_CURRENT_BINARY_DIR}/profiling_info.pb.cc ${CMAKE_CURRENT_BINARY_DIR}/profiling_info.pb.h)
+
+# Generate profiling_info.pb.cc and profiling_info.pb.h from
+# profiling_info.proto using protoc. Once the protobuf package version is
+# upgraded, we can use protobuf_generate_cpp/protobuf_generate here directly.
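# (For reference, the generated rule below amounts to a plain protoc run of
# the form, with illustrative paths:
#   protoc --cpp_out=<binary dir> --proto_path=<source dir> profiling_info.proto)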
+add_custom_command(
+ OUTPUT ${proto_generated_files}
+ COMMAND ${Protobuf_PROTOC_EXECUTABLE}
+ ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} --proto_path=${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/profiling_info.proto
+ DEPENDS ${Protobuf_PROTOC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/profiling_info.proto
+)
+
+set_source_files_properties(${proto_generated_files} PROPERTIES GENERATED TRUE)
+target_sources(profiling_info_proto PRIVATE ${proto_generated_files})
+target_link_libraries(profiling_info_proto protobuf::libprotobuf)
+target_include_directories(profiling_info_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
+
+# Move all generated proto files to the TFLITE_GENERATED_HEADERS_DIR
+add_custom_command(
+ TARGET profiling_info_proto POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_directory
+ ${CMAKE_CURRENT_BINARY_DIR}
+ ${TFLITE_GENERATED_HEADERS_DIR}/profiling/proto)
\ No newline at end of file
diff --git a/tensorflow/lite/profiling/proto/profiling_info.proto b/tensorflow/lite/profiling/proto/profiling_info.proto
new file mode 100644
index 00000000000000..8116524405dc11
--- /dev/null
+++ b/tensorflow/lite/profiling/proto/profiling_info.proto
@@ -0,0 +1,63 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +syntax = "proto2"; + +package tflite.profiling; + +option java_multiple_files = true; + +message BenchmarkProfilingData { + optional string model_name = 1; + optional ModelProfilingData init_profile = 2; + optional ModelProfilingData runtime_profile = 3; +} + +message ModelProfilingData { + repeated SubGraphProfilingData subgraph_profiles = 1; + repeated DelegateProfilingData delegate_profiles = 2; +} + +message SubGraphProfilingData { + optional string subgraph_name = 1; + optional int32 subgraph_index = 2; + repeated OpProfileData per_op_profiles = 3; +} + +message DelegateProfilingData { + optional string delegate_name = 1; + repeated OpProfileData per_op_profiles = 2; +} + +message OpProfilingStat { + optional int64 first = 1; + optional int64 last = 2; + optional int64 avg = 3; + optional float stddev = 4; + optional float variance = 5; + optional int64 min = 6; + optional int64 max = 7; + optional int64 sum = 8; + optional int64 count = 9; +} + +message OpProfileData { + optional string node_type = 1; + optional OpProfilingStat inference_microseconds = 2; + optional OpProfilingStat mem_kb = 3; + optional int64 times_called = 4; + optional string name = 5; + optional int64 run_order = 6; +} diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index 4e85652310481e..8ff6d3939d996b 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -197,6 +197,7 @@ py_strict_library( "//tensorflow/compiler/mlir/quantization/tensorflow/python:representative_dataset", "//tensorflow/core:protos_all_py", "//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py", + "//tensorflow/lite/profiling/proto:profiling_info_py", "//tensorflow/lite/python/metrics", "//tensorflow/lite/python/optimize:calibrator", "//tensorflow/lite/tools:flatbuffer_utils", diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index 6155575b4048b5..7bf0f18d68fc24 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -144,28 +144,6 @@ flatbuffer_cc_library( out_prefix = "reflection/", ) -# Schema test to make sure we don't introduce backward incompatible changes -# to schemas. -cc_test( - name = "flatbuffer_compatibility_test", - size = "small", - srcs = ["flatbuffer_compatibility_test.cc"], - data = [ - "schema.fbs", - "schema_v3b.fbs", - ], - tags = [ - "no_oss", - "tflite_not_portable_android", - "tflite_not_portable_ios", - ], - deps = [ - "//tensorflow/core/platform", - "@com_google_googletest//:gtest_main", - "@flatbuffers//:flatc_library", - ], -) - cc_library( name = "schema_utils", hdrs = ["schema_utils.h"], diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index 8c60f8ad012bd8..b08a2d913b6ec7 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -257,6 +257,7 @@ cc_library_with_tflite( cc_library( name = "logging", hdrs = ["logging.h"], + compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), ) diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index fc2a1be282f985..56794382ff45a8 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -45,6 +45,11 @@ list(APPEND TFLITE_BENCHMARK_LIBS tensorflow-lite ) +list(APPEND TFLITE_BENCHMARK_LIBS + profiling_info_proto + protobuf::libprotobuf +) + # TODO(b/171007016): Enable performance options on Windows. 
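An illustrative consumer sketch: with op profiling enabled and the proto output mode selected (the `op_profiling_output_mode` and `op_profiling_output_file` options documented below), the output file holds a single serialized `tflite.profiling.BenchmarkProfilingData` message defined above, so it can be read back with ordinary protobuf calls. A minimal sketch, assuming only the generated header and a hypothetical file path:

#include <fstream>
#include <ios>
#include <iostream>

#include "tensorflow/lite/profiling/proto/profiling_info.pb.h"

int main() {
  // Hypothetical path; use whatever was passed to --op_profiling_output_file.
  std::ifstream file("/tmp/profile.proto", std::ios::binary);
  tflite::profiling::BenchmarkProfilingData data;
  if (!file.good() || !data.ParseFromIstream(&file)) return 1;
  // Walk the per-subgraph, per-op timings from the regular benchmark runs.
  for (const auto& subgraph : data.runtime_profile().subgraph_profiles()) {
    for (const auto& op : subgraph.per_op_profiles()) {
      std::cout << subgraph.subgraph_name() << ": " << op.name() << " ("
                << op.node_type() << ") avg "
                << op.inference_microseconds().avg() << " us\n";
    }
  }
  return 0;
}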
if(NOT "${CMAKE_SYSTEM_NAME}" STREQUAL "Windows")
list(APPEND TFLITE_BENCHMARK_SRCS
@@ -92,6 +97,10 @@ target_compile_options(benchmark_model
PRIVATE
${TFLITE_BENCHMARK_CC_OPTIONS}
)
+target_include_directories(benchmark_model
+ PUBLIC
+ ${CMAKE_BINARY_DIR}
+)
target_link_libraries(benchmark_model
${TFLITE_BENCHMARK_LIBS}
)
diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md
index f25da51705d6b8..e92d841b9c6a87 100644
--- a/tensorflow/lite/tools/benchmark/README.md
+++ b/tensorflow/lite/tools/benchmark/README.md
@@ -67,7 +67,24 @@
and the following optional parameters:
thus it is preferred to set `max_profiling_buffer_entries` to a
large-enough value.
+* `op_profiling_output_mode`: `str` (default="stdout") \
+ The output mode for the profiling information generated. Requires
+ `enable_op_profiling` to be `true`. Takes one of the following 3 values:
+ - `stdout` : Print profiling information to STDOUT.
+ - `csv` : Print the profiling information in a CSV format.
+ - `proto` : Print the profiling information in a proto format as specified
+ in `tensorflow/lite/profiling/proto/profiling_info.proto`.
+* `op_profiling_output_file`: `str` (default="") \
+ File path to export profile data to. The results are printed to
+ `stdout` if option is not set. Requires `enable_op_profiling` to be `true`
+ and the path to include the name of the output file; otherwise results are
+ printed to `stdout`.
+
* `profiling_output_csv_file`: `str` (default="") \
+
+ WARNING: Deprecated, prefer using `op_profiling_output_mode` and
+ `op_profiling_output_file` instead.
+
File path to export profile data to as CSV. The results are printed to
`stdout` if option is not set. Requires `enable_op_profiling` to be `true`
and the path to include the name of the output CSV; otherwise results are
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index d775122fe9c1fc..8fb5b23b7860d9 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -78,6 +78,15 @@ constexpr bool kOpProfilingEnabledDefault = true;
constexpr bool kOpProfilingEnabledDefault = false;
#endif
+// Op profiling output modes.
+constexpr char kOpProfilingOutputModeStdout[] = "stdout";
+constexpr char kOpProfilingOutputModeCsv[] = "csv";
+constexpr char kOpProfilingOutputModeProto[] = "proto";
+
+const char* kOpProfilingOutputModes[] = {kOpProfilingOutputModeStdout,
+ kOpProfilingOutputModeCsv,
+ kOpProfilingOutputModeProto};
+
// Dumps ruy profiling events if the ruy profiler is enabled.
class RuyProfileListener : public BenchmarkListener {
public:
@@ -310,10 +319,14 @@ TfLiteStatus PopulateInputLayerInfo(
}
std::shared_ptr<profiling::ProfileSummaryFormatter>
-CreateProfileSummaryFormatter(bool format_as_csv) {
- return format_as_csv
- ? std::make_shared<profiling::ProfileSummaryCSVFormatter>()
- : std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
+CreateProfileSummaryFormatter(const std::string& output_mode) {
+ if (output_mode == kOpProfilingOutputModeCsv) {
+ return std::make_shared<profiling::ProfileSummaryCSVFormatter>();
+ } else if (output_mode == kOpProfilingOutputModeProto) {
+ return std::make_shared<profiling::ProfileSummaryProtoFormatter>();
+ } else {
+ return std::make_shared<profiling::ProfileSummaryDefaultFormatter>();
+ }
}
} // namespace
@@ -479,6 +492,11 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
default_params.AddParam(
"enable_op_profiling",
BenchmarkParam::Create<bool>(kOpProfilingEnabledDefault));
+ default_params.AddParam(
+ "op_profiling_output_mode",
+ BenchmarkParam::Create<std::string>(kOpProfilingOutputModeStdout));
+ default_params.AddParam("op_profiling_output_file",
+ BenchmarkParam::Create<std::string>(""));
default_params.AddParam("max_profiling_buffer_entries",
BenchmarkParam::Create<int32_t>(1024));
default_params.AddParam("allow_dynamic_profiling_buffer_increase",
@@ -565,14 +583,21 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
CreateFlag<bool>("require_full_delegation", &params_,
"require delegate to run the entire graph"),
CreateFlag<bool>("enable_op_profiling", &params_, "enable op profiling"),
+ CreateFlag<std::string>(
+ "op_profiling_output_mode", &params_,
+ "Output mode for op profiling results. Supported values are: "
+ "'stdout', 'csv' and 'proto'."),
+ CreateFlag<std::string>("op_profiling_output_file", &params_,
+ "Output file for op profiling results."),
CreateFlag<int32_t>("max_profiling_buffer_entries", &params_,
"max initial profiling buffer entries"),
CreateFlag<bool>("allow_dynamic_profiling_buffer_increase", &params_,
"allow dynamic increase on profiling buffer entries"),
- CreateFlag<std::string>(
- "profiling_output_csv_file", &params_,
- "File path to export profile data as CSV, if not set "
- "prints to stdout."),
+ CreateFlag<std::string>("profiling_output_csv_file", &params_,
+ "[DEPRECATED: Use op_profiling_output_file and "
+ "op_profiling_output_mode instead] File path to "
+ "export profile data as CSV, if not set "
+ "prints to stdout."),
CreateFlag<bool>(
"print_preinvoke_state", &params_,
"print out the interpreter internals just before calling Invoke. The "
@@ -650,6 +675,10 @@ void BenchmarkTfLiteModel::LogParams() {
"Require full delegation", verbose);
LOG_BENCHMARK_PARAM(bool, "enable_op_profiling", "Enable op profiling",
verbose);
+ LOG_BENCHMARK_PARAM(std::string, "op_profiling_output_mode",
+ "Op profiling output mode.", verbose);
+ LOG_BENCHMARK_PARAM(std::string, "op_profiling_output_file",
+ "Op profiling output file.", verbose);
LOG_BENCHMARK_PARAM(int32_t, "max_profiling_buffer_entries",
"Max initial profiling buffer entries", verbose);
LOG_BENCHMARK_PARAM(bool, "allow_dynamic_profiling_buffer_increase",
@@ -693,6 +722,31 @@ TfLiteStatus BenchmarkTfLiteModel::ValidateParams() {
return kTfLiteError;
}
+ if (params_.Get<bool>("enable_op_profiling")) {
+ bool found =
+ std::find(std::begin(kOpProfilingOutputModes),
+ std::end(kOpProfilingOutputModes),
+ params_.Get<std::string>("op_profiling_output_mode")) !=
+ std::end(kOpProfilingOutputModes);
+
+ if (!found) {
+ TFLITE_LOG(ERROR) << "Output mode "
+ << params_.Get<std::string>("op_profiling_output_mode")
+ << " is not supported. Supported values are: 'stdout', "
+ "'csv' and 'proto'.";
+ return kTfLiteError;
+ }
+
+ if (!params_.Get<std::string>("profiling_output_csv_file").empty()) {
+ // Backward compatibility for profiling_output_csv_file.
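// (Illustrative equivalence, with a hypothetical path: setting only the
// deprecated flag
//   --profiling_output_csv_file=/tmp/profile.csv
// is rewritten below into the equivalent of
//   --op_profiling_output_mode=csv --op_profiling_output_file=/tmp/profile.csv)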
+ params_.Set<std::string>("op_profiling_output_mode",
+ kOpProfilingOutputModeCsv);
+ params_.Set<std::string>(
+ "op_profiling_output_file",
+ params_.Get<std::string>("profiling_output_csv_file"));
+ }
+ }
+
return PopulateInputLayerInfo(
params_.Get<std::string>("input_layer"),
params_.Get<std::string>("input_layer_shape"),
@@ -1123,9 +1177,9 @@ BenchmarkTfLiteModel::MayCreateProfilingListener() const {
return std::unique_ptr<BenchmarkListener>(new ProfilingListener(
interpreter_.get(), params_.Get<int32_t>("max_profiling_buffer_entries"),
params_.Get<bool>("allow_dynamic_profiling_buffer_increase"),
- params_.Get<std::string>("profiling_output_csv_file"),
+ params_.Get<std::string>("op_profiling_output_file"),
CreateProfileSummaryFormatter(
- !params_.Get<std::string>("profiling_output_csv_file").empty())));
+ params_.Get<std::string>("op_profiling_output_mode"))));
}
TfLiteStatus BenchmarkTfLiteModel::RunImpl() {
diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.cc b/tensorflow/lite/tools/benchmark/profiling_listener.cc
index eff38b0da05f5d..0099c4f8e5fe19 100644
--- a/tensorflow/lite/tools/benchmark/profiling_listener.cc
+++ b/tensorflow/lite/tools/benchmark/profiling_listener.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include
#include
+#include "tensorflow/lite/profiling/profile_summarizer.h"
#include "tensorflow/lite/tools/logging.h"
namespace tflite {
namespace benchmark {
ProfilingListener::ProfilingListener(
Interpreter* interpreter, uint32_t max_num_initial_entries,
- bool allow_dynamic_buffer_increase, const std::string& csv_file_path,
+ bool allow_dynamic_buffer_increase, const std::string& output_file_path,
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter)
: run_summarizer_(summarizer_formatter),
init_summarizer_(summarizer_formatter),
- csv_file_path_(csv_file_path),
+ output_file_path_(output_file_path),
interpreter_(interpreter),
- profiler_(max_num_initial_entries, allow_dynamic_buffer_increase) {
+ profiler_(max_num_initial_entries, allow_dynamic_buffer_increase),
+ summarizer_formatter_(summarizer_formatter) {
TFLITE_TOOLS_CHECK(interpreter);
interpreter_->SetProfiler(&profiler_);
@@ -66,27 +68,9 @@ void ProfilingListener::OnSingleRunEnd() { }
void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) {
- std::ofstream output_file(csv_file_path_);
- std::ostream* output_stream = nullptr;
- if (output_file.good()) {
- output_stream = &output_file;
- }
- if (init_summarizer_.HasProfiles()) {
- WriteOutput("Profiling Info for Benchmark Initialization:",
- init_summarizer_.GetOutputString(),
- output_stream == nullptr ? &TFLITE_LOG(INFO) : output_stream);
- }
- if (run_summarizer_.HasProfiles()) {
- WriteOutput("Operator-wise Profiling Info for Regular Benchmark Runs:",
- run_summarizer_.GetOutputString(),
- output_stream == nullptr ? &TFLITE_LOG(INFO) : output_stream);
- }
-}
-
-void ProfilingListener::WriteOutput(const std::string& header,
- const string& data, std::ostream* stream) {
- (*stream) << header << std::endl;
- (*stream) << data << std::endl;
+ summarizer_formatter_->HandleOutput(init_summarizer_.GetOutputString(),
+ run_summarizer_.GetOutputString(),
+ output_file_path_);
}
} // namespace benchmark
diff --git a/tensorflow/lite/tools/benchmark/profiling_listener.h b/tensorflow/lite/tools/benchmark/profiling_listener.h
index a9957ddb06b7b1..03869e3df5fe31 100644
--- a/tensorflow/lite/tools/benchmark/profiling_listener.h
+++ b/tensorflow/lite/tools/benchmark/profiling_listener.h
@@ -32,7 +32,8 @@ class ProfilingListener : public BenchmarkListener {
public:
ProfilingListener(
Interpreter* interpreter, uint32_t max_num_initial_entries,
- bool allow_dynamic_buffer_increase, const std::string& csv_file_path = "",
+ bool allow_dynamic_buffer_increase,
+ const std::string& output_file_path = "",
std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter =
std::make_shared<profiling::ProfileSummaryDefaultFormatter>());
@@ -47,13 +48,12 @@ class ProfilingListener : public BenchmarkListener {
protected:
profiling::ProfileSummarizer run_summarizer_;
profiling::ProfileSummarizer init_summarizer_;
- std::string csv_file_path_;
+ std::string output_file_path_;
private:
- void WriteOutput(const std::string& header, const string& data,
- std::ostream* stream);
Interpreter* interpreter_;
profiling::BufferedProfiler profiler_;
+ std::shared_ptr<profiling::ProfileSummaryFormatter> summarizer_formatter_;
};
} // namespace benchmark
diff --git a/tensorflow/lite/tools/cmake/modules/FindProtobuf.cmake b/tensorflow/lite/tools/cmake/modules/FindProtobuf.cmake
new file mode 100644
index 00000000000000..3641e8a69e86b0
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/FindProtobuf.cmake
@@ -0,0 +1,16 @@
+#
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+include(protobuf)
\ No newline at end of file
diff --git a/tensorflow/lite/tools/cmake/modules/protobuf.cmake b/tensorflow/lite/tools/cmake/modules/protobuf.cmake
new file mode 100644
index 00000000000000..de09cdeda9c370
--- /dev/null
+++ b/tensorflow/lite/tools/cmake/modules/protobuf.cmake
@@ -0,0 +1,45 @@
+#
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +include(OverridableFetchContent) + +OverridableFetchContent_Declare( + protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf + # Sync with tensorflow/third_party/protobuf/protobuf.patch + GIT_TAG 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66 + GIT_PROGRESS TRUE + PREFIX "${CMAKE_BINARY_DIR}" + SOURCE_DIR "${CMAKE_BINARY_DIR}/protobuf" +) + +set(protobuf_ABSL_PROVIDER "package" CACHE STRING "" FORCE) +set(protobuf_BUILD_TESTS OFF CACHE BOOL "" FORCE) +set(protobuf_BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) +set(protobuf_INSTALL OFF CACHE BOOL "" FORCE) +set(protobuf_WITH_ZLIB OFF CACHE BOOL "" FORCE) +set(protobuf_BUILD_PROTOC_BINARIES ON CACHE BOOL "" FORCE) + +OverridableFetchContent_GetProperties(protobuf) +if(NOT protobuf_POPULATED) + OverridableFetchContent_Populate(protobuf) +endif() + +set(Protobuf_INCLUDE_DIR "${protobuf_SOURCE_DIR}/src" CACHE INTERNAL "") +set(Protobuf_LIBRARIES protobuf::libprotobuf CACHE INTERNAL "") + +add_subdirectory(${protobuf_SOURCE_DIR} ${protobuf_BINARY_DIR}) + +set(Protobuf_PROTOC_EXECUTABLE protoc CACHE INTERNAL "") diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 75fedf76ac1ed9..cc1ee13ce3de3f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 5, 24) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 5, 28) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 30880f5a27eaad..88dc5ae24a6833 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,156 +1,1763 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp ---- a/clang/lib/Sema/SemaTemplate.cpp -+++ b/clang/lib/Sema/SemaTemplate.cpp -@@ -1807,8 +1807,6 @@ - // Returns the template parameter list with all default template argument - // information. - static TemplateParameterList *GetTemplateParameterList(TemplateDecl *TD) { -- if (TD->isImplicit()) -- return TD->getTemplateParameters(); - // Make sure we get the template parameter list from the most - // recent declaration, since that is the only one that is guaranteed to - // have all the default template argument information. 
-@@ -1829,8 +1827,7 @@ - // template friend struct C; - // }; - // template struct S; -- while ((D->isImplicit() || -- D->getFriendObjectKind() != Decl::FriendObjectKind::FOK_None) && -+ while (D->getFriendObjectKind() != Decl::FriendObjectKind::FOK_None && - D->getPreviousDecl()) - D = D->getPreviousDecl(); - return cast(D)->getTemplateParameters(); -diff -ruN --strip-trailing-cr a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp ---- a/clang/lib/Sema/SemaTemplateDeduction.cpp -+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp -@@ -527,8 +527,8 @@ - R->setDefaultArgument( - S.Context, - S.getTrivialTemplateArgumentLoc(Default, QualType(), SourceLocation())); -- if (T->hasTypeConstraint()) { -- auto *C = T->getTypeConstraint(); -+ if (R->hasTypeConstraint()) { -+ auto *C = R->getTypeConstraint(); - R->setTypeConstraint(C->getConceptReference(), - C->getImmediatelyDeclaredConstraint()); - } -@@ -583,53 +583,37 @@ - return TemplateDeductionResult::Success; - - auto NewDeduced = DeducedTemplateArgument(Arg); -- // Provisional resolution for CWG2398: If Arg names a template -- // specialization, then we deduce a synthesized template template parameter -- // based on A, but using the TS's arguments as defaults. -- if (DefaultArguments.size() != 0) { -+ // Provisional resolution for CWG2398: If Arg is also a template template -+ // param, and it names a template specialization, then we deduce a -+ // synthesized template template parameter based on A, but using the TS's -+ // arguments as defaults. -+ if (auto *TempArg = dyn_cast_or_null( -+ Arg.getAsTemplateDecl())) { - assert(Arg.getKind() == TemplateName::Template); -- TemplateDecl *TempArg = Arg.getAsTemplateDecl(); -- TemplateParameterList *As = TempArg->getTemplateParameters(); -- assert(DefaultArguments.size() <= As->size()); -- -- SmallVector Params(As->size()); -- for (unsigned I = 0; I < DefaultArguments.size(); ++I) -- Params[I] = getTemplateParameterWithDefault(S, As->getParam(I), -- DefaultArguments[I]); -- for (unsigned I = DefaultArguments.size(); I < As->size(); ++I) -- Params[I] = As->getParam(I); -- // FIXME: We could unique these, and also the parameters, but we don't -- // expect programs to contain a large enough amount of these deductions -- // for that to be worthwhile. -- auto *TPL = TemplateParameterList::Create( -- S.Context, SourceLocation(), SourceLocation(), Params, -- SourceLocation(), As->getRequiresClause()); -+ assert(!TempArg->isExpandedParameterPack()); - -- TemplateDecl *TD; -- switch (TempArg->getKind()) { -- case Decl::TemplateTemplateParm: { -- auto *A = cast(TempArg); -- assert(!A->isExpandedParameterPack()); -- TD = TemplateTemplateParmDecl::Create( -- S.Context, A->getDeclContext(), SourceLocation(), A->getDepth(), -- A->getPosition(), A->isParameterPack(), A->getIdentifier(), -- A->wasDeclaredWithTypename(), TPL); -- break; +diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp +--- a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp ++++ b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.cpp +@@ -1,744 +0,0 @@ +-//===- AMDGPUSplitModule.cpp ----------------------------------------------===// +-// +-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +-// See https://llvm.org/LICENSE.txt for license information. 
+-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +-// +-//===----------------------------------------------------------------------===// +-// +-/// \file Implements a module splitting algorithm designed to support the +-/// FullLTO --lto-partitions option for parallel codegen. This is completely +-/// different from the common SplitModule pass, as this system is designed with +-/// AMDGPU in mind. +-/// +-/// The basic idea of this module splitting implementation is the same as +-/// SplitModule: load-balance the module's functions across a set of N +-/// partitions to allow parallel codegen. However, it does it very +-/// differently than the target-agnostic variant: +-/// - Kernels are used as the module's "roots". +-/// They're known entry points on AMDGPU, and everything else is often +-/// internal only. +-/// - Each kernel has a set of dependencies, and when a kernel and its +-/// dependencies is considered "big", we try to put it in a partition where +-/// most dependencies are already imported, to avoid duplicating large +-/// amounts of code. +-/// - There's special care for indirect calls in order to ensure +-/// AMDGPUResourceUsageAnalysis can work correctly. +-/// +-/// This file also includes a more elaborate logging system to enable +-/// users to easily generate logs that (if desired) do not include any value +-/// names, in order to not leak information about the source file. +-/// Such logs are very helpful to understand and fix potential issues with +-/// module splitting. +- +-#include "AMDGPUSplitModule.h" +-#include "AMDGPUTargetMachine.h" +-#include "Utils/AMDGPUBaseInfo.h" +-#include "llvm/ADT/DenseMap.h" +-#include "llvm/ADT/SmallVector.h" +-#include "llvm/ADT/StringExtras.h" +-#include "llvm/ADT/StringRef.h" +-#include "llvm/Analysis/CallGraph.h" +-#include "llvm/Analysis/TargetTransformInfo.h" +-#include "llvm/IR/Function.h" +-#include "llvm/IR/Instruction.h" +-#include "llvm/IR/Module.h" +-#include "llvm/IR/User.h" +-#include "llvm/IR/Value.h" +-#include "llvm/Support/Casting.h" +-#include "llvm/Support/Debug.h" +-#include "llvm/Support/FileSystem.h" +-#include "llvm/Support/Path.h" +-#include "llvm/Support/Process.h" +-#include "llvm/Support/SHA256.h" +-#include "llvm/Support/Threading.h" +-#include "llvm/Support/raw_ostream.h" +-#include "llvm/Transforms/Utils/Cloning.h" +-#include +-#include +-#include +-#include +-#include +-#include +- +-using namespace llvm; +- +-#define DEBUG_TYPE "amdgpu-split-module" +- +-namespace { +- +-static cl::opt LargeKernelFactor( +- "amdgpu-module-splitting-large-kernel-threshold", cl::init(2.0f), +- cl::Hidden, +- cl::desc( +- "consider a kernel as large and needing special treatment when it " +- "exceeds the average cost of a partition by this factor; e;g. 
2.0 " +- "means if the kernel and its dependencies is 2 times bigger than " +- "an average partition; 0 disables large kernels handling entirely")); +- +-static cl::opt LargeKernelOverlapForMerge( +- "amdgpu-module-splitting-large-kernel-merge-overlap", cl::init(0.8f), +- cl::Hidden, +- cl::desc("defines how much overlap between two large kernel's dependencies " +- "is needed to put them in the same partition")); +- +-static cl::opt NoExternalizeGlobals( +- "amdgpu-module-splitting-no-externalize-globals", cl::Hidden, +- cl::desc("disables externalization of global variable with local linkage; " +- "may cause globals to be duplicated which increases binary size")); +- +-static cl::opt +- LogDirOpt("amdgpu-module-splitting-log-dir", cl::Hidden, +- cl::desc("output directory for AMDGPU module splitting logs")); +- +-static cl::opt +- LogPrivate("amdgpu-module-splitting-log-private", cl::Hidden, +- cl::desc("hash value names before printing them in the AMDGPU " +- "module splitting logs")); +- +-using CostType = InstructionCost::CostType; +-using PartitionID = unsigned; +- +-static bool isEntryPoint(const Function *F) { +- return AMDGPU::isEntryFunctionCC(F->getCallingConv()); +-} +- +-static std::string getName(const Value &V) { +- static bool HideNames; +- +- static llvm::once_flag HideNameInitFlag; +- llvm::call_once(HideNameInitFlag, [&]() { +- if (LogPrivate.getNumOccurrences()) +- HideNames = LogPrivate; +- else { +- const auto EV = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_PRIVATE"); +- HideNames = (EV.value_or("0") != "0"); +- } +- }); +- +- if (!HideNames) +- return V.getName().str(); +- return toHex(SHA256::hash(arrayRefFromStringRef(V.getName())), +- /*LowerCase=*/true); +-} +- +-/// Main logging helper. +-/// +-/// Logging can be configured by the following environment variable. +-/// AMD_SPLIT_MODULE_LOG_DIR= +-/// If set, uses as the directory to write logfiles to +-/// each time module splitting is used. +-/// AMD_SPLIT_MODULE_LOG_PRIVATE +-/// If set to anything other than zero, all names are hidden. +-/// +-/// Both environment variables have corresponding CL options which +-/// takes priority over them. +-/// +-/// Any output printed to the log files is also printed to dbgs() when -debug is +-/// used and LLVM_DEBUG is defined. +-/// +-/// This approach has a small disadvantage over LLVM_DEBUG though: logging logic +-/// cannot be removed from the code (by building without debug). This probably +-/// has a small performance cost because if some computation/formatting is +-/// needed for logging purpose, it may be done everytime only to be ignored +-/// by the logger. +-/// +-/// As this pass only runs once and is not doing anything computationally +-/// expensive, this is likely a reasonable trade-off. +-/// +-/// If some computation should really be avoided when unused, users of the class +-/// can check whether any logging will occur by using the bool operator. +-/// +-/// \code +-/// if (SML) { +-/// // Executes only if logging to a file or if -debug is available and +-/// used. +-/// } +-/// \endcode +-class SplitModuleLogger { +-public: +- SplitModuleLogger(const Module &M) { +- std::string LogDir = LogDirOpt; +- if (LogDir.empty()) +- LogDir = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_DIR").value_or(""); +- +- // No log dir specified means we don't need to log to a file. +- // We may still log to dbgs(), though. +- if (LogDir.empty()) +- return; +- +- // If a log directory is specified, create a new file with a unique name in +- // that directory. 
+- int Fd; +- SmallString<0> PathTemplate; +- SmallString<0> RealPath; +- sys::path::append(PathTemplate, LogDir, "Module-%%-%%-%%-%%-%%-%%-%%.txt"); +- if (auto Err = +- sys::fs::createUniqueFile(PathTemplate.str(), Fd, RealPath)) { +- report_fatal_error("Failed to create log file at '" + Twine(LogDir) + +- "': " + Err.message(), +- /*CrashDiag=*/false); +- } +- +- FileOS = std::make_unique(Fd, /*shouldClose=*/true); +- } +- +- bool hasLogFile() const { return FileOS != nullptr; } +- +- raw_ostream &logfile() { +- assert(FileOS && "no logfile!"); +- return *FileOS; +- } +- +- /// \returns true if this SML will log anything either to a file or dbgs(). +- /// Can be used to avoid expensive computations that are ignored when logging +- /// is disabled. +- operator bool() const { +- return hasLogFile() || (DebugFlag && isCurrentDebugType(DEBUG_TYPE)); +- } +- +-private: +- std::unique_ptr FileOS; +-}; +- +-template +-static SplitModuleLogger &operator<<(SplitModuleLogger &SML, const Ty &Val) { +- static_assert( +- !std::is_same_v, +- "do not print values to logs directly, use handleName instead!"); +- LLVM_DEBUG(dbgs() << Val); +- if (SML.hasLogFile()) +- SML.logfile() << Val; +- return SML; +-} +- +-/// Calculate the cost of each function in \p M +-/// \param SML Log Helper +-/// \param TM TargetMachine instance used to retrieve TargetTransformInfo. +-/// \param M Module to analyze. +-/// \param CostMap[out] Resulting Function -> Cost map. +-/// \return The module's total cost. +-static CostType +-calculateFunctionCosts(SplitModuleLogger &SML, const AMDGPUTargetMachine &TM, +- Module &M, +- DenseMap &CostMap) { +- CostType ModuleCost = 0; +- CostType KernelCost = 0; +- +- for (auto &Fn : M) { +- if (Fn.isDeclaration()) +- continue; +- +- CostType FnCost = 0; +- TargetTransformInfo TTI = TM.getTargetTransformInfo(Fn); +- +- for (const auto &BB : Fn) { +- for (const auto &I : BB) { +- auto Cost = +- TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); +- assert(Cost != InstructionCost::getMax()); +- // Assume expensive if we can't tell the cost of an instruction. 
+- CostType CostVal = +- Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive); +- assert((FnCost + CostVal) >= FnCost && "Overflow!"); +- FnCost += CostVal; - } -- case Decl::ClassTemplate: { -- auto *A = cast(TempArg); -- auto *CT = ClassTemplateDecl::Create(S.Context, A->getDeclContext(), -- SourceLocation(), A->getDeclName(), -- TPL, A->getTemplatedDecl()); -- CT->setPreviousDecl(A); -- TD = CT; -- break; +- } +- +- assert(FnCost != 0); +- +- CostMap[&Fn] = FnCost; +- assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!"); +- ModuleCost += FnCost; +- +- if (isEntryPoint(&Fn)) +- KernelCost += FnCost; +- } +- +- CostType FnCost = (ModuleCost - KernelCost); +- SML << "=> Total Module Cost: " << ModuleCost << '\n' +- << " => KernelCost: " << KernelCost << " (" +- << format("%0.2f", (float(KernelCost) / ModuleCost) * 100) << "%)\n" +- << " => FnsCost: " << FnCost << " (" +- << format("%0.2f", (float(FnCost) / ModuleCost) * 100) << "%)\n"; +- +- return ModuleCost; +-} +- +-static bool canBeIndirectlyCalled(const Function &F) { +- if (F.isDeclaration() || isEntryPoint(&F)) +- return false; +- return !F.hasLocalLinkage() || +- F.hasAddressTaken(/*PutOffender=*/nullptr, +- /*IgnoreCallbackUses=*/false, +- /*IgnoreAssumeLikeCalls=*/true, +- /*IgnoreLLVMUsed=*/true, +- /*IgnoreARCAttachedCall=*/false, +- /*IgnoreCastedDirectCall=*/true); +-} +- +-/// When a kernel or any of its callees performs an indirect call, this function +-/// takes over \ref addAllDependencies and adds all potentially callable +-/// functions to \p Fns so they can be counted as dependencies of the kernel. +-/// +-/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the +-/// presence of an indirect call, the function's resource usage is the same as +-/// the most expensive function in the module. +-/// \param M The module. +-/// \param Fns[out] Resulting list of functions. +-static void addAllIndirectCallDependencies(const Module &M, +- DenseSet &Fns) { +- for (const auto &Fn : M) { +- if (canBeIndirectlyCalled(Fn)) +- Fns.insert(&Fn); +- } +-} +- +-/// Adds the functions that \p Fn may call to \p Fns, then recurses into each +-/// callee until all reachable functions have been gathered. +-/// +-/// \param SML Log Helper +-/// \param CG Call graph for \p Fn's module. +-/// \param Fn Current function to look at. +-/// \param Fns[out] Resulting list of functions. +-/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some +-/// point, either in \p Fn or in one of the function it calls. When that +-/// happens, we fall back to adding all callable functions inside \p Fn's module +-/// to \p Fns. +-static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG, +- const Function &Fn, +- DenseSet &Fns, +- bool &HadIndirectCall) { +- assert(!Fn.isDeclaration()); +- +- const Module &M = *Fn.getParent(); +- SmallVector WorkList({&Fn}); +- while (!WorkList.empty()) { +- const auto &CurFn = *WorkList.pop_back_val(); +- assert(!CurFn.isDeclaration()); +- +- // Scan for an indirect call. If such a call is found, we have to +- // conservatively assume this can call all non-entrypoint functions in the +- // module. +- +- for (auto &CGEntry : *CG[&CurFn]) { +- auto *CGNode = CGEntry.second; +- auto *Callee = CGNode->getFunction(); +- if (!Callee) { +- // Functions have an edge towards CallsExternalNode if they're external +- // declarations, or if they do an indirect call. As we only process +- // definitions here, we know this means the function has an indirect +- // call. 
We then have to conservatively assume this can call all +- // non-entrypoint functions in the module. +- if (CGNode != CG.getCallsExternalNode()) +- continue; // this is another function-less node we don't care about. +- +- SML << "Indirect call detected in " << getName(CurFn) +- << " - treating all non-entrypoint functions as " +- "potential dependencies\n"; +- +- // TODO: Print an ORE as well ? +- addAllIndirectCallDependencies(M, Fns); +- HadIndirectCall = true; +- return; - } -- default: -- llvm_unreachable("Unexpected Template Kind"); -+ TemplateParameterList *As = TempArg->getTemplateParameters(); -+ if (DefaultArguments.size() != 0) { -+ assert(DefaultArguments.size() <= As->size()); -+ SmallVector Params(As->size()); -+ for (unsigned I = 0; I < DefaultArguments.size(); ++I) -+ Params[I] = getTemplateParameterWithDefault(S, As->getParam(I), -+ DefaultArguments[I]); -+ for (unsigned I = DefaultArguments.size(); I < As->size(); ++I) -+ Params[I] = As->getParam(I); -+ // FIXME: We could unique these, and also the parameters, but we don't -+ // expect programs to contain a large enough amount of these deductions -+ // for that to be worthwhile. -+ auto *TPL = TemplateParameterList::Create( -+ S.Context, SourceLocation(), SourceLocation(), Params, -+ SourceLocation(), As->getRequiresClause()); -+ NewDeduced = DeducedTemplateArgument( -+ TemplateName(TemplateTemplateParmDecl::Create( -+ S.Context, TempArg->getDeclContext(), SourceLocation(), -+ TempArg->getDepth(), TempArg->getPosition(), -+ TempArg->isParameterPack(), TempArg->getIdentifier(), -+ TempArg->wasDeclaredWithTypename(), TPL))); - } -- TD->setImplicit(true); -- NewDeduced = DeducedTemplateArgument(TemplateName(TD)); - } - - DeducedTemplateArgument Result = checkDeducedTemplateArguments(S.Context, -diff -ruN --strip-trailing-cr a/clang/test/CXX/temp/temp.decls/temp.alias/p2.cpp b/clang/test/CXX/temp/temp.decls/temp.alias/p2.cpp ---- a/clang/test/CXX/temp/temp.decls/temp.alias/p2.cpp -+++ b/clang/test/CXX/temp/temp.decls/temp.alias/p2.cpp -@@ -28,14 +28,13 @@ - { /* ... */ } - - template class TT> -- void f(TT); -+ void f(TT); // expected-note {{candidate template ignored}} - - template class TT> - void g(TT>); +- +- if (Callee->isDeclaration()) +- continue; +- +- auto [It, Inserted] = Fns.insert(Callee); +- if (Inserted) +- WorkList.push_back(Callee); +- } +- } +-} +- +-/// Contains information about a kernel and its dependencies. +-struct KernelWithDependencies { +- KernelWithDependencies(SplitModuleLogger &SML, CallGraph &CG, +- const DenseMap &FnCosts, +- const Function *Fn) +- : Fn(Fn) { +- addAllDependencies(SML, CG, *Fn, Dependencies, HasIndirectCall); +- TotalCost = FnCosts.at(Fn); +- for (const auto *Dep : Dependencies) { +- TotalCost += FnCosts.at(Dep); +- +- // We cannot duplicate functions with external linkage, or functions that +- // may be overriden at runtime. +- HasNonDuplicatableDependecy |= +- (Dep->hasExternalLinkage() || !Dep->isDefinitionExact()); +- } +- } +- +- const Function *Fn = nullptr; +- DenseSet Dependencies; +- /// Whether \p Fn or any of its \ref Dependencies contains an indirect call. +- bool HasIndirectCall = false; +- /// Whether any of \p Fn's dependencies cannot be duplicated. +- bool HasNonDuplicatableDependecy = false; +- +- CostType TotalCost = 0; +- +- /// \returns true if this kernel and its dependencies can be considered large +- /// according to \p Threshold. 
+- bool isLarge(CostType Threshold) const { +- return TotalCost > Threshold && !Dependencies.empty(); +- } +-}; +- +-/// Calculates how much overlap there is between \p A and \p B. +-/// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A +-/// and B have no shared elements. Kernels do not count in overlap calculation. +-static float calculateOverlap(const DenseSet &A, +- const DenseSet &B) { +- DenseSet Total; +- for (const auto *F : A) { +- if (!isEntryPoint(F)) +- Total.insert(F); +- } +- +- if (Total.empty()) +- return 0.0f; +- +- unsigned NumCommon = 0; +- for (const auto *F : B) { +- if (isEntryPoint(F)) +- continue; +- +- auto [It, Inserted] = Total.insert(F); +- if (!Inserted) +- ++NumCommon; +- } +- +- return static_cast(NumCommon) / Total.size(); +-} +- +-/// Performs all of the partitioning work on \p M. +-/// \param SML Log Helper +-/// \param M Module to partition. +-/// \param NumParts Number of partitions to create. +-/// \param ModuleCost Total cost of all functions in \p M. +-/// \param FnCosts Map of Function -> Cost +-/// \param WorkList Kernels and their dependencies to process in order. +-/// \returns The created partitions (a vector of size \p NumParts ) +-static std::vector> +-doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts, +- CostType ModuleCost, +- const DenseMap &FnCosts, +- const SmallVector &WorkList) { +- +- SML << "\n--Partitioning Starts--\n"; +- +- // Calculate a "large kernel threshold". When more than one kernel's total +- // import cost exceeds this value, we will try to merge it with other, +- // similarly large kernels. +- // +- // e.g. let two kernels X and Y have a import cost of ~10% of the module, we +- // assign X to a partition as usual, but when we get to Y, we check if it's +- // worth also putting it in Y's partition. +- const CostType LargeKernelThreshold = +- LargeKernelFactor ? ((ModuleCost / NumParts) * LargeKernelFactor) +- : std::numeric_limits::max(); +- +- std::vector> Partitions; +- Partitions.resize(NumParts); +- +- // Assign a partition to each kernel, and try to keep the partitions more or +- // less balanced. We do that through a priority queue sorted in reverse, so we +- // can always look at the partition with the least content. +- // +- // There are some cases where we will be deliberately unbalanced though. +- // - Large kernels: we try to merge with existing partitions to reduce code +- // duplication. +- // - Kernels with indirect or external calls always go in the first partition +- // (P0). +- auto ComparePartitions = [](const std::pair &a, +- const std::pair &b) { +- // When two partitions have the same cost, assign to the one with the +- // biggest ID first. This allows us to put things in P0 last, because P0 may +- // have other stuff added later. +- if (a.second == b.second) +- return a.first < b.first; +- return a.second > b.second; +- }; +- +- // We can't use priority_queue here because we need to be able to access any +- // element. This makes this a bit inefficient as we need to sort it again +- // everytime we change it, but it's a very small array anyway (likely under 64 +- // partitions) so it's a cheap operation. +- std::vector> BalancingQueue; +- for (unsigned I = 0; I < NumParts; ++I) +- BalancingQueue.push_back(std::make_pair(I, 0)); +- +- // Helper function to handle assigning a kernel to a partition. This takes +- // care of updating the balancing queue. 
+- const auto AssignToPartition = [&](PartitionID PID, +- const KernelWithDependencies &KWD) { +- auto &FnsInPart = Partitions[PID]; +- FnsInPart.insert(KWD.Fn); +- FnsInPart.insert(KWD.Dependencies.begin(), KWD.Dependencies.end()); +- +- SML << "assign " << getName(*KWD.Fn) << " to P" << PID << "\n -> "; +- if (!KWD.Dependencies.empty()) { +- SML << KWD.Dependencies.size() << " dependencies added\n"; +- }; +- +- // Update the balancing queue. we scan backwards because in the common case +- // the partition is at the end. +- for (auto &[QueuePID, Cost] : reverse(BalancingQueue)) { +- if (QueuePID == PID) { +- CostType NewCost = 0; +- for (auto *Fn : Partitions[PID]) +- NewCost += FnCosts.at(Fn); +- +- SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost; +- if (Cost) { +- SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100) +- << "% increase)"; +- } +- SML << '\n'; +- +- Cost = NewCost; +- } +- } +- +- sort(BalancingQueue, ComparePartitions); +- }; +- +- for (auto &CurKernel : WorkList) { +- // When a kernel has indirect calls, it must stay in the first partition +- // alongside every reachable non-entry function. This is a nightmare case +- // for splitting as it severely limits what we can do. +- if (CurKernel.HasIndirectCall) { +- SML << "Kernel with indirect call(s): " << getName(*CurKernel.Fn) +- << " defaulting to P0\n"; +- AssignToPartition(0, CurKernel); +- continue; +- } +- +- // When a kernel has non duplicatable dependencies, we have to keep it in +- // the first partition as well. This is a conservative approach, a +- // finer-grained approach could keep track of which dependencies are +- // non-duplicatable exactly and just make sure they're grouped together. +- if (CurKernel.HasNonDuplicatableDependecy) { +- SML << "Kernel with externally visible dependency " +- << getName(*CurKernel.Fn) << " defaulting to P0\n"; +- AssignToPartition(0, CurKernel); +- continue; +- } +- +- // Be smart with large kernels to avoid duplicating their dependencies. +- if (CurKernel.isLarge(LargeKernelThreshold)) { +- assert(LargeKernelOverlapForMerge >= 0.0f && +- LargeKernelOverlapForMerge <= 1.0f); +- SML << "Large Kernel: " << getName(*CurKernel.Fn) +- << " - looking for partition with at least " +- << format("%0.2f", LargeKernelOverlapForMerge * 100) << "% overlap\n"; +- +- bool Assigned = false; +- for (const auto &[PID, Fns] : enumerate(Partitions)) { +- float Overlap = calculateOverlap(CurKernel.Dependencies, Fns); +- SML << " => " << format("%0.2f", Overlap * 100) << "% overlap with P" +- << PID << '\n'; +- if (Overlap > LargeKernelOverlapForMerge) { +- SML << " selecting P" << PID << '\n'; +- AssignToPartition(PID, CurKernel); +- Assigned = true; +- } +- } +- +- if (Assigned) +- continue; +- } +- +- // Normal "load-balancing", assign to partition with least pressure. +- auto [PID, CurCost] = BalancingQueue.back(); +- AssignToPartition(PID, CurKernel); +- } +- +- // Work is mostly done now, verify the partioning and add all functions we may +- // have missed (= unreachable, or we don't understand how they're reached) to +- // P0. +- DenseSet AllFunctions; +- for (const auto &[Idx, Part] : enumerate(Partitions)) { +- CostType Cost = 0; +- for (auto *Fn : Part) { +- // external linkage functions should exclusively be in the first partition +- // at this stage. In theory, we should only ever see external linkage +- // functions here if they're kernels, or if they've been added due to a +- // kernel using indirect calls somewhere in its CallGraph. 
+-      assert(Idx == 0 || (!Fn->hasExternalLinkage() || isEntryPoint(Fn)));
+-      Cost += FnCosts.at(Fn);
+-    }
+-    SML << "P" << Idx << " has a total cost of " << Cost << " ("
+-        << format("%0.2f", (float(Cost) / ModuleCost) * 100)
+-        << "% of source module)\n";
+-    AllFunctions.insert(Part.begin(), Part.end());
+-  }
+-
+-  // Add missed functions to P0. This will take care of adding things like
+-  // external functions with no callers in the module to P0. This should be
+-  // fairly rare as AMDGPU internalizes everything in most cases, so unused
+-  // internal functions would get removed.
+-  for (auto &Fn : M) {
+-    if (!Fn.isDeclaration() && !AllFunctions.contains(&Fn)) {
+-      SML << getName(Fn) << " has no partition assigned, defaulting to P0\n";
+-      Partitions[0].insert(&Fn);
+-    }
+-  }
+-
+-  SML << "--Partitioning Done--\n\n";
+-
+-  return Partitions;
+-}
+-
+-static void externalize(GlobalValue &GV) {
+-  if (GV.hasLocalLinkage()) {
+-    GV.setLinkage(GlobalValue::ExternalLinkage);
+-    GV.setVisibility(GlobalValue::HiddenVisibility);
+-  }
+-
+-  // Unnamed entities must be named consistently between modules. setName will
+-  // give a distinct name to each such entity.
+-  if (!GV.hasName())
+-    GV.setName("__llvmsplit_unnamed");
+-}
+-} // end anonymous namespace
+-
+-void llvm::splitAMDGPUModule(
+-    const AMDGPUTargetMachine &TM, Module &M, unsigned N,
+-    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
+-
+-  SplitModuleLogger SML(M);
+-
+-  CallGraph CG(M);
+-
+-  // Externalize functions whose address is taken.
+-  //
+-  // This is needed because partitioning is purely based on calls, but sometimes
+-  // a kernel/function may just look at the address of another local function
+-  // and not do anything (no calls). After partitioning, that local function may
+-  // end up in a different module (so it's just a declaration in the module
+-  // where its address is taken), which emits an "undefined hidden symbol"
+-  // linker error.
+-  //
+-  // Additionally, it guides partitioning to not duplicate this function if it's
+-  // called directly at some point.
+-  for (auto &Fn : M) {
+-    if (Fn.hasAddressTaken()) {
+-      if (Fn.hasLocalLinkage()) {
+-        SML << "[externalize] " << Fn.getName()
+-            << " because its address is taken\n";
+-      }
+-      externalize(Fn);
+-    }
+-  }
+-
+-  // Externalize local GVs, which avoids duplicating their initializers, which
+-  // in turn helps keep code size in check.
+-  if (!NoExternalizeGlobals) {
+-    for (auto &GV : M.globals()) {
+-      if (GV.hasLocalLinkage())
+-        SML << "[externalize] GV " << GV.getName() << '\n';
+-      externalize(GV);
+-    }
+-  }
+-
+-  // Start by calculating the cost of every function in the module, as well as
+-  // the module's overall cost.
+-  DenseMap<const Function *, CostType> FnCosts;
+-  const CostType ModuleCost = calculateFunctionCosts(SML, TM, M, FnCosts);
+-
+-  // Gather every kernel into a WorkList, then sort it by descending total cost
+-  // of the kernel so the biggest kernels are seen first.
+-  SmallVector<KernelWithDependencies> WorkList;
+-  for (auto &Fn : M) {
+-    if (isEntryPoint(&Fn) && !Fn.isDeclaration())
+-      WorkList.emplace_back(SML, CG, FnCosts, &Fn);
+-  }
+-  sort(WorkList, [&](auto &A, auto &B) {
+-    // Sort by total cost, and if the total cost is identical, sort
+-    // alphabetically.
+-    if (A.TotalCost == B.TotalCost)
+-      return A.Fn->getName() < B.Fn->getName();
+-    return A.TotalCost > B.TotalCost;
+-  });
+-
+-  if (SML) {
+-    SML << "Worklist\n";
+-    for (const auto &KWD : WorkList) {
+-      SML << "[Kernel] " << getName(*KWD.Fn) << " (totalCost:" << KWD.TotalCost
+-          << " indirect:" << KWD.HasIndirectCall
+-          << " hasNonDuplicatableDep:" << KWD.HasNonDuplicatableDependecy
+-          << ")\n";
+-      for (const auto *Dep : KWD.Dependencies)
+-        SML << " [Dep] " << getName(*Dep) << '\n';
+-    }
+-  }
+-
+-  // This performs all of the partitioning work.
+-  auto Partitions = doPartitioning(SML, M, N, ModuleCost, FnCosts, WorkList);
+-  assert(Partitions.size() == N);
+-
+-  // If we didn't externalize GVs, then local GVs need to be conservatively
+-  // imported into every module (including their initializers), and then cleaned
+-  // up afterwards.
+-  const auto NeedsConservativeImport = [&](const GlobalValue *GV) {
+-    // We conservatively import private/internal GVs into every module and clean
+-    // them up afterwards.
+-    const auto *Var = dyn_cast<GlobalVariable>(GV);
+-    return Var && Var->hasLocalLinkage();
+-  };
+-
+-  SML << "Creating " << N << " modules...\n";
+-  unsigned TotalFnImpls = 0;
+-  for (unsigned I = 0; I < N; ++I) {
+-    const auto &FnsInPart = Partitions[I];
+-
+-    ValueToValueMapTy VMap;
+-    std::unique_ptr<Module> MPart(
+-        CloneModule(M, VMap, [&](const GlobalValue *GV) {
+-          // Functions go in their assigned partition.
+-          if (const auto *Fn = dyn_cast<Function>(GV)) {
+-// Check we don't import an external linkage function in any
+-// partition other than P0.
+-#ifndef NDEBUG
+-            if (Fn->hasExternalLinkage() && !isEntryPoint(Fn)) {
+-              assert((I == 0) == FnsInPart.contains(Fn));
+-            }
+-#endif
+-            return FnsInPart.contains(Fn);
+-          }
+-
+-          if (NeedsConservativeImport(GV))
+-            return true;
+-
+-          // Everything else goes in the first partition.
+-          return I == 0;
+-        }));
+-
+-    // Clean up conservatively imported GVs without any users.
+-    for (auto &GV : make_early_inc_range(MPart->globals())) {
+-      if (NeedsConservativeImport(&GV) && GV.use_empty())
+-        GV.eraseFromParent();
+-    }
+-
+-    unsigned NumAllFns = 0, NumKernels = 0;
+-    for (auto &Cur : *MPart) {
+-      if (!Cur.isDeclaration()) {
+-        ++NumAllFns;
+-        if (isEntryPoint(&Cur))
+-          ++NumKernels;
+-      }
+-    }
+-    TotalFnImpls += NumAllFns;
+-    SML << " - Module " << I << " with " << NumAllFns << " functions ("
+-        << NumKernels << " kernels)\n";
+-    ModuleCallback(std::move(MPart));
+-  }
+-
+-  SML << TotalFnImpls << " function definitions across all modules ("
+-      << format("%0.2f", (float(TotalFnImpls) / FnCosts.size()) * 100)
+-      << "% of original module)\n";
+-}
+diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.h b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.h
+--- a/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.h
++++ b/llvm/lib/Target/AMDGPU/AMDGPUSplitModule.h
+@@ -1,30 +0,0 @@
+-//===- AMDGPUSplitModule.h -------------------------------------*- C++ -*-===//
+-//
+-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+-// See https://llvm.org/LICENSE.txt for license information.
+-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+-//
+-//===----------------------------------------------------------------------===//
+-//
+-//===----------------------------------------------------------------------===//
+-
+-#ifndef LLVM_TARGET_AMDGPUSPLITMODULE_H
+-#define LLVM_TARGET_AMDGPUSPLITMODULE_H
+-
+-#include "llvm/ADT/STLFunctionalExtras.h"
+-#include <memory>
+-
+-namespace llvm {
+-
+-class Module;
+-class AMDGPUTargetMachine;
+-
+-/// Splits the module M into N linkable partitions. The function ModuleCallback
+-/// is called N times passing each individual partition as the MPart argument.
+-void splitAMDGPUModule(
+-    const AMDGPUTargetMachine &TM, Module &M, unsigned N,
+-    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback);
+-
+-} // end namespace llvm
+-
+-#endif // LLVM_TARGET_AMDGPUSPLITMODULE_H
+diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
++++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+@@ -21,7 +21,6 @@
+ #include "AMDGPUIGroupLP.h"
+ #include "AMDGPUMacroFusion.h"
+ #include "AMDGPURegBankSelect.h"
+-#include "AMDGPUSplitModule.h"
+ #include "AMDGPUTargetObjectFile.h"
+ #include "AMDGPUTargetTransformInfo.h"
+ #include "AMDGPUUnifyDivergentExitNodes.h"
+@@ -816,13 +815,6 @@
+   return AMDGPUAS::FLAT_ADDRESS;
+ }
-   int h() {
--    f(v); // OK: TT = vector, Alloc is used as the default argument for the
--          // second parameter.
-+    f(v); // expected-error {{no matching function for call to 'f'}}
-     g(v); // OK: TT = vector
-   }

-diff -ruN --strip-trailing-cr a/clang/test/SemaTemplate/cwg2398.cpp b/clang/test/SemaTemplate/cwg2398.cpp
---- a/clang/test/SemaTemplate/cwg2398.cpp
-+++ b/clang/test/SemaTemplate/cwg2398.cpp
-@@ -65,10 +65,13 @@
-   template  struct B;

+-bool AMDGPUTargetMachine::splitModule(
+-    Module &M, unsigned NumParts,
+-    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) const {
+-  splitAMDGPUModule(*this, M, NumParts, ModuleCallback);
+-  return true;
+-}
+-
+ //===----------------------------------------------------------------------===//
+ // GCN Target Machine (SI+)
+ //===----------------------------------------------------------------------===//
+diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
+--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
++++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.h
+@@ -73,10 +73,6 @@
+   getPredicatedAddrSpace(const Value *V) const override;

+   unsigned getAddressSpaceForPseudoSourceKind(unsigned Kind) const override;
+-
+-  bool splitModule(Module &M, unsigned NumParts,
+-                   function_ref<void(std::unique_ptr<Module> MPart)>
+-                       ModuleCallback) const override;
+ };

- template
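
To make the partitioning scheme removed above easier to follow outside of the LLVM sources, here is a minimal, self-contained C++ sketch of the same balancing-queue-plus-overlap heuristic. All names and types here (Kernel, overlap, partition, and plain std::set<std::string> standing in for LLVM's DenseSet<const Function *>) are invented for this illustration, and the cost update is deliberately cruder than the pass's recomputation from FnCosts:

#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

using CostType = std::uint64_t;
using PartitionID = unsigned;

struct Kernel {
  std::string Name;
  std::set<std::string> Dependencies;
  CostType TotalCost = 0;
};

// Overlap between a kernel's dependency set A and the contents of a
// partition B: |A intersect B| / |A|.
static float overlap(const std::set<std::string> &A,
                     const std::set<std::string> &B) {
  if (A.empty())
    return 0.0f;
  unsigned NumCommon = 0;
  for (const auto &F : A)
    NumCommon += B.count(F);
  return static_cast<float>(NumCommon) / A.size();
}

static std::vector<std::set<std::string>>
partition(const std::vector<Kernel> &WorkList, unsigned NumParts,
          CostType LargeKernelThreshold, float MergeOverlap) {
  std::vector<std::set<std::string>> Parts(NumParts);

  // (partition, cost) pairs, kept sorted so the least-loaded partition is
  // always at the back; ties push the biggest ID towards the back so that
  // P0 fills up last.
  std::vector<std::pair<PartitionID, CostType>> Queue;
  for (PartitionID I = 0; I != NumParts; ++I)
    Queue.push_back({I, 0});

  auto Assign = [&](PartitionID PID, const Kernel &K) {
    Parts[PID].insert(K.Name);
    Parts[PID].insert(K.Dependencies.begin(), K.Dependencies.end());
    for (auto &[QPID, Cost] : Queue)
      if (QPID == PID)
        Cost += K.TotalCost; // simplified; the pass recomputes from FnCosts
    std::sort(Queue.begin(), Queue.end(),
              [](const auto &A, const auto &B) {
                if (A.second == B.second)
                  return A.first < B.first;
                return A.second > B.second;
              });
  };

  for (const Kernel &K : WorkList) {
    // Large kernel: prefer a partition that already contains enough of its
    // dependencies (this sketch stops at the first match).
    if (K.TotalCost > LargeKernelThreshold && !K.Dependencies.empty()) {
      bool Assigned = false;
      for (PartitionID PID = 0; PID != NumParts && !Assigned; ++PID) {
        if (overlap(K.Dependencies, Parts[PID]) > MergeOverlap) {
          Assign(PID, K);
          Assigned = true;
        }
      }
      if (Assigned)
        continue;
    }
    // Otherwise, plain load balancing: take the least-loaded partition.
    Assign(Queue.back().first, K);
  }
  return Parts;
}

Keeping a plain sorted vector instead of a std::priority_queue mirrors the deleted pass's stated design choice: the queue must allow updating an arbitrary element's cost, and with at most a few dozen partitions, re-sorting after every assignment is cheap.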
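The externalization step that precedes partitioning can likewise be sketched against the public LLVM IR API. This is an approximation of the deleted externalize() helper and its two call sites, not the pass itself; externalizeForSplit and the "__split_unnamed" name are placeholders chosen for the example:

#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static void externalizeForSplit(Module &M) {
  auto Fix = [](GlobalValue &GV) {
    // Local linkage becomes external-hidden so cross-partition references
    // resolve at link time without leaking symbols.
    if (GV.hasLocalLinkage()) {
      GV.setLinkage(GlobalValue::ExternalLinkage);
      GV.setVisibility(GlobalValue::HiddenVisibility);
    }
    // Unnamed values get a stable name; setName() uniques duplicates.
    if (!GV.hasName())
      GV.setName("__split_unnamed");
  };
  for (Function &Fn : M)
    if (Fn.hasAddressTaken()) // the address may escape to another partition
      Fix(Fn);
  for (GlobalVariable &GV : M.globals())
    Fix(GV);
}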
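Finally, a hypothetical caller of the splitModule() hook whose AMDGPU override is deleted above. It assumes the base TargetMachine::splitModule() virtual from the same LLVM revision (with the signature shown in the removed override) and an already-constructed TargetMachine and Module; splitIntoParts is an illustrative name:

#include "llvm/IR/Module.h"
#include "llvm/Target/TargetMachine.h"

#include <memory>
#include <vector>

using namespace llvm;

static std::vector<std::unique_ptr<Module>>
splitIntoParts(TargetMachine &TM, Module &M, unsigned NumParts) {
  std::vector<std::unique_ptr<Module>> Parts;
  // splitModule() invokes the callback once per produced partition and
  // returns false if the target does not support module splitting.
  if (!TM.splitModule(M, NumParts, [&](std::unique_ptr<Module> MPart) {
        Parts.push_back(std::move(MPart));
      }))
    Parts.clear();
  return Parts;
}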