From 76c1cd24db69fed31c9857d5dc3a958ceed7e066 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:36:31 -0700 Subject: [PATCH] bump torchao pin (#1318) * bump torchao pin * update pin * update pin * merge conflict --- .github/workflows/pull.yml | 33 ++++---------------- docs/quantization.md | 23 +++++++++----- install/.pins/torchao-pin.txt | 2 +- torchchat/utils/quantize.py | 38 ++++++++++++++++++------ torchchat/utils/scripts/install_utils.sh | 3 +- 5 files changed, 52 insertions(+), 47 deletions(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 55fe8f11d..14b8c0712 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -1092,32 +1092,11 @@ jobs: id: install-torchao-ops run: | bash torchchat/utils/scripts/build_torchao_ops.sh - - name: Set git shas - id: setup-hash - run: | - export TORCHCHAT_ROOT=${PWD} - echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV" - - name: Load or install ET - id: install-et - uses: actions/cache@v4 - with: - path: | - ./et-build - ./torchchat/utils/scripts/install_et.sh - key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }} - - if: ${{ steps.install-et.outputs.cache-hit != 'true' }} - continue-on-error: true + - name: Install ET run: | echo "Installing ExecuTorch" + export TORCHCHAT_ROOT=${PWD} bash torchchat/utils/scripts/install_et.sh - - name: Install ExecuTorch python - run: | - echo "Install ExecuTorch python" - export TORCHCHAT_ROOT=$PWD - export ET_BUILD_DIR="et-build" - ENABLE_ET_PYBIND="${1:-true}" - source "torchchat/utils/scripts/install_utils.sh" - install_executorch_python_libs $ENABLE_ET_PYBIND - name: Install runner run: | echo "Installing runner" @@ -1132,14 +1111,14 @@ jobs: wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model export PRMT="Once upon a time in a land far away" echo "Generate eager" - python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' + python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' echo "Generate compile" - python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile + python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile echo "Export and run ET (C++ runner)" - python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' + python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}" echo "Export and run AOTI (C++ runner)" - python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' + python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}" echo "Generate AOTI" python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}" diff --git a/docs/quantization.md b/docs/quantization.md index 348a3196e..bef7309c6 100644 --- a/docs/quantization.md +++ b/docs/quantization.md @@ -121,22 +121,29 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n ## Experimental TorchAO lowbit kernels ### Use -The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize. + +#### linear:a8wxdq +The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false). The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true). Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme. -You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization. +You should expect high performance on ARM CPU if groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization. + +#### embedding:wx +The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros. + +You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization. ### Setup -To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon. +To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon. From the torchchat root directory, run ``` sh torchchat/utils/scripts/build_torchao_ops.sh ``` -This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat. +This should take about 10 seconds to complete. Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners. @@ -156,17 +163,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif #### Eager mode ``` -OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5 +OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5 ``` #### torch.compile ``` -OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5 +OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5 ``` #### AOTI ``` -OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so +OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5 ``` @@ -178,7 +185,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac #### ExecuTorch ``` -python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte +python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte ``` Note: only the ExecuTorch C++ runner in torchchat when built using the instructions in the setup can run the exported *.pte file. It will not work with the `python torchchat.py generate` command. diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt index a7422ea2e..40f083249 100644 --- a/install/.pins/torchao-pin.txt +++ b/install/.pins/torchao-pin.txt @@ -1 +1 @@ -49b1fb61c8b8eceda755579a2fd92c756d822de2 +c8f1174a06dcc0102849c8348ca6573bde8847a9 diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 8a708d416..31c639dfd 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -45,6 +45,7 @@ find_multiple, get_device_str, get_precision, + set_precision, name_to_dtype, state_dict_device, use_et_backend, @@ -52,7 +53,7 @@ # Flag for whether the a8wxdq quantizer is available. -a8wxdq_load_error: Optional[Exception] = None +torchao_experimental_load_error: Optional[Exception] = None ######################################################################### ### handle arg validation ### @@ -115,6 +116,13 @@ def quantize_model( if not support_tensor_subclass: unwrap_tensor_subclass(model) continue + + if quantizer in ["linear:a8wxdq", "embedding:wx"]: + # These quantizers require float32 input weights. Note that after quantization, + # the weights will no longer be float32, but lowbit integers + if get_precision() != torch.float32: + print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") + set_precision(torch.float32) # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat @@ -887,8 +895,9 @@ def quantized_model(self) -> nn.Module: try: import importlib.util - import sys import os + import sys + torchao_build_path = f"{os.getcwd()}/torchao-build" # Try loading quantizer @@ -896,15 +905,25 @@ def quantized_model(self) -> nn.Module: "torchao_experimental_quant_api", f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py", ) - torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec) + torchao_experimental_quant_api = importlib.util.module_from_spec( + torchao_experimental_quant_api_spec + ) sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api - torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api) - from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer - quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer + torchao_experimental_quant_api_spec.loader.exec_module( + torchao_experimental_quant_api + ) + from torchao_experimental_quant_api import ( + Int8DynActIntxWeightLinearQuantizer, + IntxWeightEmbeddingQuantizer, + ) + + quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer + quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer # Try loading custom op try: import glob + libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*") libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) torch.ops.load_library(libs[0]) @@ -915,8 +934,9 @@ def quantized_model(self) -> nn.Module: except Exception as e: class ErrorHandler(QuantHandler): def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None): - global a8wxdq_load_error - raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}") + global torchao_experimental_load_error + raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}") - a8wxdq_load_error = e + torchao_experimental_load_error = e quantizer_class_dict["linear:a8wxdq"] = ErrorHandler + quantizer_class_dict["embedding:wx"] = ErrorHandler diff --git a/torchchat/utils/scripts/install_utils.sh b/torchchat/utils/scripts/install_utils.sh index 10405382e..84966cc35 100644 --- a/torchchat/utils/scripts/install_utils.sh +++ b/torchchat/utils/scripts/install_utils.sh @@ -191,7 +191,6 @@ install_torchao_aten_ops() { cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \ -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \ -DCMAKE_BUILD_TYPE="Release" \ - -DTORCHAO_OP_TARGET="aten" \ -S . \ -B ${CMAKE_OUT_DIR} -G Ninja cmake --build ${CMAKE_OUT_DIR} --target install --config Release @@ -207,7 +206,7 @@ install_torchao_executorch_ops() { cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \ -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \ -DCMAKE_BUILD_TYPE="Release" \ - -DTORCHAO_OP_TARGET="executorch" \ + -DTORCHAO_BUILD_EXECUTORCH_OPS=ON \ -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \ -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \ -S . \