Update on "[ET-VK] Using a single GPU buffer for all tensor uniforms."
This diff changes the Tensor class to store all uniforms in a single uniform buffer.

The entities stored in uniforms (i.e. size, stride, numel, and logical limits) are now packed into a single buffer, and their offsets are stored as unsigned ints in the Tensor class.

Other changes include:

Adding a new ctor to the ParamsBuffer class to allow allocating by size, without a data ptr.

Adding an offset input to the Buffer::data function.

Adding an offset parameter to the BufferBindInfo ctor, so an additional offset can be supplied when binding a buffer. A sketch of how these pieces fit together follows.
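
A minimal C++ sketch of the combined API, assuming hypothetical member names, an illustrative 16-byte-slot layout, and a host-memory stand-in for the GPU buffer; the real ET-VK classes wrap Vulkan buffer objects:

// Sketch only: hypothetical names, not the actual ET-VK implementation.
#include <cstdint>
#include <cstring>
#include <vector>

class ParamsBuffer {
 public:
  // New ctor from this diff: allocate `size` bytes without a data ptr.
  explicit ParamsBuffer(size_t size) : storage_(size) {}

  // Buffer::data now accepts an offset into the allocation.
  void* data(size_t offset = 0) { return storage_.data() + offset; }

 private:
  std::vector<uint8_t> storage_;
};

struct BufferBindInfo {
  // New offset parameter: the descriptor binding starts `offset` bytes in.
  BufferBindInfo(ParamsBuffer& buf, size_t offset = 0)
      : buffer(&buf), offset(offset) {}

  ParamsBuffer* buffer;
  size_t offset;
};

struct Tensor {
  // All uniforms (size, stride, numel, logical limits) share one buffer;
  // each entity is located by an unsigned-int byte offset into it.
  Tensor() : uniforms(4 * 16) {}  // illustrative: four 16-byte slots

  void set_numel(int32_t numel) {
    std::memcpy(uniforms.data(numel_offset), &numel, sizeof(numel));
  }

  // Binding a uniform for a shader supplies the per-entity offset.
  BufferBindInfo numel_bind_info() {
    return BufferBindInfo(uniforms, numel_offset);
  }

  ParamsBuffer uniforms;
  uint32_t sizes_offset = 0;
  uint32_t strides_offset = 16;
  uint32_t numel_offset = 32;
  uint32_t logical_limits_offset = 48;
};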

Differential Revision: [D65841750](https://our.internmc.facebook.com/intern/diff/D65841750/)

[ghstack-poisoned]
trivedivivek committed Dec 5, 2024
2 parents 9db69e5 + 7265606 commit f3bc1e6
Showing 51 changed files with 535 additions and 357 deletions.
12 changes: 9 additions & 3 deletions CMakeLists.txt
@@ -200,8 +200,6 @@ option(EXECUTORCH_BUILD_EXTENSION_TENSOR "Build the Tensor extension" OFF)

option(EXECUTORCH_BUILD_EXTENSION_TRAINING "Build the training extension" OFF)

option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)

option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)

option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF)
@@ -216,6 +214,8 @@ option(EXECUTORCH_BUILD_KERNELS_QUANTIZED "Build the quantized kernels" OFF)

option(EXECUTORCH_BUILD_DEVTOOLS "Build the ExecuTorch Developer Tools")

option(EXECUTORCH_BUILD_TESTS "Build CMake-based unit tests" OFF)

option(EXECUTORCH_NNLIB_OPT "Build Cadence backend Hifi nnlib kernel" OFF)

option(EXECUTORCH_CADENCE_CPU_RUNNER "Build Cadence backend CPU runner" OFF)
@@ -330,6 +330,10 @@ if(EXECUTORCH_BUILD_PTHREADPOOL)
)
endif()

if(EXECUTORCH_BUILD_TESTS)
include(CTest)
endif()

if(NOT PYTHON_EXECUTABLE)
resolve_python_executable()
endif()
@@ -625,7 +629,7 @@ cmake_dependent_option(
)

# Add googletest if any test targets should be built
-if(EXECUTORCH_BUILD_GTESTS)
+if(BUILD_TESTING)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/googletest)
endif()

@@ -829,5 +833,7 @@ if(EXECUTORCH_BUILD_VULKAN)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/vulkan)
endif()

include(Test.cmake)

# Print all summary
executorch_print_configuration_summary()
29 changes: 29 additions & 0 deletions Test.cmake
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# A helper CMake file to trigger C++ unit tests.
#

if(BUILD_TESTING)
  # This contains the list of tests which are always built
  add_subdirectory(extension/evalue_util/test)
  add_subdirectory(extension/kernel_util/test)
  add_subdirectory(extension/memory_allocator/test)
  add_subdirectory(extension/parallel/test)
  add_subdirectory(extension/pytree/test)
  add_subdirectory(kernels/portable/cpu/util/test)
  add_subdirectory(kernels/prim_ops/test)
  add_subdirectory(kernels/test)
  add_subdirectory(runtime/core/exec_aten/testing_util/test)
  add_subdirectory(runtime/core/exec_aten/util/test)
  add_subdirectory(runtime/core/portable_type/test)
  add_subdirectory(runtime/core/test)
  add_subdirectory(runtime/executor/test)
  add_subdirectory(runtime/kernel/test)
  add_subdirectory(runtime/platform/test)
  add_subdirectory(test/utils)
endif()
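
With this wiring, configuring with -DEXECUTORCH_BUILD_TESTS=ON pulls in CTest (whose include defaults BUILD_TESTING to ON), which in turn builds googletest and the suites listed above, so they can be run with ctest from the build directory.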
91 changes: 91 additions & 0 deletions backends/cadence/build_cadence_fusionG3.sh
@@ -0,0 +1,91 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -euo pipefail

unset CMAKE_PREFIX_PATH
unset XTENSA_CORE
export XTENSA_CORE=FCV_FG3GP
git submodule sync
git submodule update --init
./install_requirements.sh

rm -rf cmake-out

STEPWISE_BUILD=false

if $STEPWISE_BUILD; then
echo "Building ExecuTorch"
cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_ENABLE_EVENT_TRACER=OFF \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
-DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
-DEXECUTORCH_BUILD_CPUINFO=OFF \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_USE_DL=OFF \
-DEXECUTORCH_BUILD_CADENCE=OFF \
-DFLATC_EXECUTABLE="$(which flatc)" \
-DHAVE_FNMATCH_H=OFF \
-Bcmake-out .

echo "Building any Cadence-specific binaries on top"
cmake -DBUCK2="$BUCK" \
-DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_HOST_TARGETS=ON \
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
-DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
-DEXECUTORCH_BUILD_CADENCE=ON \
-DFLATC_EXECUTABLE="$(which flatc)" \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \
-DEXECUTORCH_USE_DL=OFF \
-DBUILD_EXECUTORCH_PORTABLE_OPS=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF \
-DPYTHON_EXECUTABLE=python3 \
-DEXECUTORCH_FUSION_G3_OPT=ON \
-DEXECUTORCH_BUILD_GFLAGS=ON \
-DHAVE_FNMATCH_H=OFF \
-Bcmake-out/backends/cadence \
backends/cadence
cmake --build cmake-out/backends/cadence -j8
else
echo "Building Cadence toolchain with ExecuTorch packages"
cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags"
cmake -DBUCK2="$BUCK" \
-DCMAKE_PREFIX_PATH="${cmake_prefix_path}" \
-DHAVE_SYS_STAT_H=ON \
-DCMAKE_TOOLCHAIN_FILE=./backends/cadence/cadence.cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_BUILD_HOST_TARGETS=ON \
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \
-DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
-DEXECUTORCH_BUILD_CPUINFO=OFF \
-DEXECUTORCH_BUILD_FLATC=OFF \
-DEXECUTORCH_BUILD_CADENCE=ON \
-DFLATC_EXECUTABLE="$(which flatc)" \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_ENABLE_PROGRAM_VERIFICATION=ON \
-DEXECUTORCH_USE_DL=OFF \
-DBUILD_EXECUTORCH_PORTABLE_OPS=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF \
-DPYTHON_EXECUTABLE=python3 \
-DEXECUTORCH_FUSION_G3_OPT=ON \
-DHAVE_FNMATCH_H=OFF \
-Bcmake-out
cmake --build cmake-out --target install --config Release -j8
fi

echo "Run simple model to verify cmake build"
python3 -m examples.portable.scripts.export --model_name="add"
xt-run --turbo cmake-out/executor_runner --model_path=add.pte
@@ -8,6 +8,8 @@
set -euo pipefail

unset CMAKE_PREFIX_PATH
unset XTENSA_CORE
export XTENSA_CORE=nxp_rt600_RI23_11_newlib
git submodule sync
git submodule update --init
./install_requirements.sh
@@ -53,7 +55,7 @@ if $STEPWISE_BUILD; then
-DHAVE_FNMATCH_H=OFF \
-Bcmake-out/backends/cadence \
backends/cadence
-cmake --build cmake-out/backends/cadence -j16
+cmake --build cmake-out/backends/cadence -j8
else
echo "Building Cadence toolchain with ExecuTorch packages"
cmake_prefix_path="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags"
@@ -79,7 +81,7 @@ else
-DEXECUTORCH_NNLIB_OPT=ON \
-DHAVE_FNMATCH_H=OFF \
-Bcmake-out
-cmake --build cmake-out --target install --config Release -j16
+cmake --build cmake-out --target install --config Release -j8
fi

echo "Run simple model to verify cmake build"
3 changes: 1 addition & 2 deletions backends/cadence/hifi/operators/op_mean.cpp
@@ -145,8 +145,7 @@ Tensor& mean_dim_out(
  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
    ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-      const size_t num =
-          torch::executor::exeget_reduced_dim_product(in, dim_list);
+      const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);
      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
        CTYPE_OUT sum = 0;
        if (in.numel() > 0) {
4 changes: 4 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -14,6 +14,7 @@
    op_ceil,
    op_clamp,
    op_conv2d,
    op_cos,
    op_depth_to_space,
    op_dequantize,
    op_div,
@@ -43,6 +44,7 @@
    op_rsqrt,
    op_select_copy,
    op_sigmoid,
    op_sin,
    op_skip_ops,
    op_slice_copy,
    op_softmax,
@@ -71,6 +73,7 @@
    op_ceil,
    op_clamp,
    op_conv2d,
    op_cos,
    op_depth_to_space,
    op_dequantize,
    op_div,
@@ -100,6 +103,7 @@
    op_rsqrt,
    op_select_copy,
    op_sigmoid,
    op_sin,
    op_skip_ops,
    op_slice_copy,
    op_softmax,
56 changes: 56 additions & 0 deletions backends/qualcomm/builders/op_cos.py
@@ -0,0 +1,56 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseCos, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Cos(NodeVisitor):
    target = ["aten.cos.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        cos_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseCos.op_name,
        )
        cos_op.AddInputTensors([input_tensor_wrapper])
        cos_op.AddOutputTensors([output_tensor_wrapper])

        return cos_op
56 changes: 56 additions & 0 deletions backends/qualcomm/builders/op_sin.py
@@ -0,0 +1,56 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseSin, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Sin(NodeVisitor):
    target = ["aten.sin.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        sin_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseSin.op_name,
        )
        sin_op.AddInputTensors([input_tensor_wrapper])
        sin_op.AddOutputTensors([output_tensor_wrapper])

        return sin_op
10 changes: 10 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -85,6 +85,11 @@ class OpElementWiseCeil:
    op_name = "ElementWiseCeil"


@dataclass(init=False, frozen=True)
class OpElementWiseCos:
    op_name: str = "ElementWiseCos"


@dataclass(init=False, frozen=True)
class OpElementWiseDivide:
    op_name: str = "ElementWiseDivide"
@@ -113,6 +118,11 @@ class OpElementWiseRsqrt:
    op_name: str = "ElementWiseRsqrt"


@dataclass(init=False, frozen=True)
class OpElementWiseSin:
    op_name: str = "ElementWiseSin"


@dataclass(init=False, frozen=True)
class OpElementWiseSubtract:
    op_name = "ElementWiseSubtract"
10 changes: 10 additions & 0 deletions backends/qualcomm/quantizer/annotators.py
@@ -271,6 +271,16 @@ def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.cos.default])
def annotate_cos(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sin.default])
def annotate_sin(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)
