Add torchao (#1182)
* init

* update install utils

* update

* update libs

* update torchao pin

* fix ci test

* add python et install to ci

* fix ci errors

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes

* fixes
metascroy authored Sep 27, 2024
1 parent e4b36f9 commit 1980a69
Showing 11 changed files with 451 additions and 123 deletions.
343 changes: 232 additions & 111 deletions .github/workflows/pull.yml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ __pycache__/
# Build directories
build/android/*
et-build/*
torchao-build/*
runner-et/cmake-out/*
runner-aoti/cmake-out/*
cmake-out/
69 changes: 69 additions & 0 deletions docs/quantization.md
@@ -118,6 +118,75 @@ python3 torchchat.py export llama3 --quantize '{"embedding": {"bitwidth": 4, "gr
python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my name is"
```

## Experimental TorchAO lowbit kernels

### Use
The quantization scheme a8wxdq dynamically quantizes activations to 8 bits and quantizes weights groupwise with a specified bitwidth and groupsize.
It takes arguments bitwidth (2, 3, 4, 5, 6, or 7), groupsize, and has_weight_zeros (true or false).
The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
Roughly speaking, {bitwidth: 4, groupsize: 256, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
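
To make this concrete, here is a minimal PyTorch sketch of groupwise affine quantization for a single group. It illustrates the general idea under the definitions above; it is not torchao's actual kernel code, and the function name and rounding details are illustrative assumptions.

```
import torch

def quantize_group(w, bitwidth=4, has_weight_zeros=False):
    # Illustrative sketch only, not torchao's implementation.
    # Signed integer range for the chosen bitwidth, e.g. [-8, 7] for 4 bits.
    qmin, qmax = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
    if has_weight_zeros:
        # Scales and zeros: asymmetric mapping that covers [w.min(), w.max()].
        scale = (w.max() - w.min()) / (qmax - qmin)
        zero = qmin - torch.round(w.min() / scale)
        q = torch.clamp(torch.round(w / scale) + zero, qmin, qmax)
        return scale * (q - zero)  # dequantized approximation of w
    else:
        # Scales only: symmetric mapping centered at zero.
        scale = w.abs().max() / qmax
        q = torch.clamp(torch.round(w / scale), qmin, qmax)
        return scale * q

# Each consecutive group of `groupsize` weights gets its own scale (and zero),
# so a weight tensor carries one scale per group.
w = torch.randn(256)
w_hat = quantize_group(w, bitwidth=4, has_weight_zeros=False)
```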

You should expect high performance on ARM CPUs if bitwidth is 2, 3, 4, or 5 and groupsize is divisible by 16. For other platforms and argument choices, a slow fallback kernel will be used, and you will see warnings about this during quantization.

### Setup
To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.

From the torchchat root directory, run
```
sh torchchat/utils/scripts/build_torchao_ops.sh
```

This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.

Note: if you want to use the new kernels in the AOTI and ExecuTorch C++ runners, you must pass the flag `link_torchao_ops` when running the scripts that build the runners.

```
sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
```

```
sh torchchat/utils/scripts/build_native.sh et link_torchao_ops
```

Note: before running `sh torchchat/utils/scripts/build_native.sh et link_torchao_ops`, you must first install ExecuTorch with `sh torchchat/utils/scripts/install_et.sh` if you have not already done so.

### Examples

Below we show how to use the new kernels. Except for ExecuTorch, you can control the number of threads by setting OMP_NUM_THREADS (as with PyTorch in general). This is optional; if you do not set it, a default number of threads is chosen automatically.

#### Eager mode
```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
```

#### torch.compile
```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
```

#### AOTI
```
OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
```

If you built the AOTI runner with `link_torchao_ops` as discussed in the setup section, you can also use the C++ runner:

```
OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```

#### ExecuTorch
```
python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
```

Note: the exported *.pte file can only be run with torchchat's ExecuTorch C++ runner, built using the instructions in the setup section; it will not work with the `python torchchat.py generate` command.

```
./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```

## Quantization Profiles

Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, `pi5.json`
1 change: 1 addition & 0 deletions install/.pins/torchao-pin.txt
@@ -0,0 +1 @@
63cb7a9857654784f726fec75c0dc36167094d8a
4 changes: 4 additions & 0 deletions runner/aoti.cmake
@@ -28,3 +28,7 @@ if(Torch_FOUND)
target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m)
set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17)
endif()

if (LINK_TORCHAO_OPS)
target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_ATEN${CMAKE_SHARED_LIBRARY_SUFFIX}")
endif()
8 changes: 8 additions & 0 deletions runner/et.cmake
@@ -116,6 +116,14 @@ if(executorch_FOUND)
target_link_libraries(et_run PRIVATE log)
endif()

if(LINK_TORCHAO_OPS)
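# WHOLE_ARCHIVE forces every object in the static library into the link, so the
# static initializers that register the custom ops are not discarded by the linker.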
target_link_libraries(et_run PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_EXECUTORCH.a>")
target_link_libraries(et_run PRIVATE
"${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_kernels_aarch64.a"
"${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_ops_linear_EXECUTORCH.a"
)
endif()

else()
MESSAGE(WARNING "ExecuTorch package not found")
endif()
47 changes: 43 additions & 4 deletions torchchat/utils/quantize.py
@@ -96,10 +96,19 @@ def quantize_model(
precision = get_precision()

try:
# Easier to ask forgiveness than permission
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device, precision=precision
)
if quantizer == "linear:a8wxdq":
quant_handler = ao_quantizer_class_dict[quantizer](
device=device,
precision=precision,
bitwidth=q_kwargs.get("bitwidth", 4),
groupsize=q_kwargs.get("groupsize", 128),
has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
)
else:
# Easier to ask forgiveness than permission
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device, precision=precision
)
except TypeError as e:
if "unexpected keyword argument 'device'" in str(e):
quant_handler = ao_quantizer_class_dict[quantizer](
@@ -861,3 +870,33 @@ def quantized_model(self) -> nn.Module:
"linear:int4": Int4WeightOnlyQuantizer,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}

try:
import importlib.util
import sys
import os
torchao_build_path = f"{os.getcwd()}/torchao-build"

# Try loading quantizer
torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
"torchao_experimental_quant_api",
f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
)
torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer

# Try loading custom op
try:
import glob
libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/liblinear_a8wxdq_ATEN.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])
except Exception as e:
print("Failed to torchao ops library with error: ", e)
print("Slow fallback kernels will be used.")

except Exception as e:
print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")
24 changes: 22 additions & 2 deletions torchchat/utils/scripts/build_native.sh
@@ -26,6 +26,7 @@ if [ $# -eq 0 ]; then
exit 1
fi

LINK_TORCHAO_OPS=OFF
while (( "$#" )); do
case "$1" in
-h|--help)
@@ -42,6 +43,11 @@ while (( "$#" )); do
TARGET="et"
shift
;;
link_torchao_ops)
echo "Linking with torchao ops..."
LINK_TORCHAO_OPS=ON
shift
;;
*)
echo "Invalid option: $1"
show_help
@@ -66,14 +72,28 @@ if [[ "$TARGET" == "et" ]]; then
echo "Make sure you run install_executorch_libs"
exit 1
fi

if [[ "$LINK_TORCHAO_OPS" == "ON" ]]; then
if [ ! -d "${TORCHCHAT_ROOT}/torchao-build" ]; then
echo "Directory ${TORCHCHAT_ROOT}/torchao-build does not exist."
echo "Make sure you run clone_torchao"
exit 1
fi

source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
find_cmake_prefix_path
EXECUTORCH_INCLUDE_DIRS="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/include;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src"
EXECUTORCH_LIBRARIES="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libexecutorch_no_prim_ops.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libextension_threadpool.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libcpuinfo.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libpthreadpool.a"
install_torchao_executorch_ops
fi
fi
popd

# CMake commands
if [[ "$TARGET" == "et" ]]; then
cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
else
cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
fi
cmake --build ./cmake-out --target "${TARGET}"_run

16 changes: 16 additions & 0 deletions torchchat/utils/scripts/build_torchao_ops.sh
@@ -0,0 +1,16 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.



source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"

pushd ${TORCHCHAT_ROOT}
find_cmake_prefix_path
clone_torchao
install_torchao_aten_ops
popd
6 changes: 0 additions & 6 deletions torchchat/utils/scripts/install_et.sh
@@ -19,10 +19,4 @@ pushd ${TORCHCHAT_ROOT}
find_cmake_prefix_path
clone_executorch
install_executorch_libs $ENABLE_ET_PYBIND
install_executorch_python_libs $ENABLE_ET_PYBIND
# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
# has no attribute 'utils'' error from evaluate CI jobs and remove
# `import lm_eval` from torchchat.py since it requires a specific version
# of numpy.
pip install numpy=='1.26.4'
popd
55 changes: 55 additions & 0 deletions torchchat/utils/scripts/install_utils.sh
@@ -93,6 +93,13 @@ install_executorch_python_libs() {
echo "Installing pybind"
bash ./install_requirements.sh --pybind xnnpack
fi

# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
# has no attribute 'utils'' error from evaluate CI jobs and remove
# `import lm_eval` from torchchat.py since it requires a specific version
# of numpy.
pip install numpy=='1.26.4'

pip3 list
popd
}
@@ -161,3 +168,51 @@ install_executorch_libs() {
install_executorch_cpp_libs
install_executorch_python_libs $1
}

clone_torchao() {
echo "Cloning torchao to ${TORCHCHAT_ROOT}/torchao-build/src"
rm -rf ${TORCHCHAT_ROOT}/torchao-build
mkdir -p ${TORCHCHAT_ROOT}/torchao-build/src
pushd ${TORCHCHAT_ROOT}/torchao-build/src
echo $PWD

git clone https://github.com/pytorch/ao.git
cd ao
git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)

popd
}

install_torchao_aten_ops() {
echo "Building torchao custom ops for ATen"
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental

CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
-DCMAKE_BUILD_TYPE="Release" \
-DTORCHAO_OP_TARGET="ATEN" \
-S . \
-B ${CMAKE_OUT_DIR} -G Ninja
cmake --build ${CMAKE_OUT_DIR} --target install --config Release

popd
}

install_torchao_executorch_ops() {
echo "Building torchao custom ops for ExecuTorch"
pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental

CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/torchao-build/cmake-out"
cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
-DCMAKE_BUILD_TYPE="Release" \
-DTORCHAO_OP_TARGET="EXECUTORCH" \
-DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
-DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
-S . \
-B ${CMAKE_OUT_DIR} -G Ninja
cmake --build ${CMAKE_OUT_DIR} --target install --config Release

popd
}
