diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 3e92ed9c0..9d3ad63e3 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -621,71 +621,87 @@ jobs:
         python torchchat.py remove stories15m
   test-mps:
-    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
-    with:
-      runner: macos-m1-stable # neeps MPS, was macos-m1-stable
-      script: |
-        export PYTHON_VERSION="3.10"
-        set -x
-        # NS/MC: Remove previous installation of torch and torchao first
-        # as this script does not install anything into conda env but rather as system dep
-        pip3 uninstall -y torch || true
-        set -eou pipefail
-
-        pip3 uninstall -y torchao || true
-        set -eou pipefail
-
-        echo "::group::Print machine info"
-        uname -a
-        sysctl machdep.cpu.brand_string
-        sysctl machdep.cpu.core_count
-        echo "::endgroup::"
+    strategy:
+      matrix:
+        runner: [macos-m1-stable]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v2
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Run test
+        run: |
+          export PYTHON_VERSION="3.10"
+          set -x
+          # NS/MC: Remove previous installation of torch and torchao first
+          # as this script does not install anything into conda env but rather as system dep
+          pip3 uninstall -y torch || true
+          set -eou pipefail
 
-        echo "::group::Install requirements"
-        # Install requirements
-        ./install/install_requirements.sh
-        ls -la
-        pwd
-        pip3 list
-        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
-        echo "::endgroup::"
+          pip3 uninstall -y torchao || true
+          set -eou pipefail
 
-        echo "::group::Download checkpoints"
-        (
-          mkdir -p checkpoints/stories15M
-          pushd checkpoints/stories15M
-          curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
-          curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
-          popd
-        )
-        echo "::endgroup::"
+          echo "::group::Print machine info"
+          uname -a
+          sysctl machdep.cpu.brand_string
+          sysctl machdep.cpu.core_count
+          echo "::endgroup::"
 
-        echo "::group::Run inference"
-        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
-        export MODEL_NAME=stories15M
-        export MODEL_DIR=/tmp
+          echo "::group::Install requirements"
+          # Install requirements
+          ./install/install_requirements.sh
+          ls -la
+          pwd
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+          echo "::endgroup::"
+
+          echo "::group::Download checkpoints"
+          (
+            mkdir -p checkpoints/stories15M
+            pushd checkpoints/stories15M
+            curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+            curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+            popd
+          )
+          echo "::endgroup::"
+
+          echo "::group::Run inference"
+          export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+          export MODEL_NAME=stories15M
+          export MODEL_DIR=/tmp
 
-        python3 torchchat.py generate --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
+          python3 torchchat.py generate --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
 
-        echo "************************************************************"
-        echo "*** embedding"
-        echo "************************************************************"
+          echo "************************************************************"
+          echo "*** embedding"
+          echo "************************************************************"
 
-        python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
-        python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          python3 torchchat.py generate --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-        echo "************************************************************"
-        echo "*** linear int8"
-        echo "************************************************************"
+          echo "************************************************************"
+          echo "*** linear int8"
+          echo "************************************************************"
 
-        python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
-        python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          python3 torchchat.py generate --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-        echo "************************************************************"
-        echo "*** linear int4"
-        echo "************************************************************"
+          echo "************************************************************"
+          echo "*** linear int4"
+          echo "************************************************************"
 
-        PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
   test-gguf-util:
     strategy:
       matrix:
@@ -734,66 +750,82 @@ jobs:
         echo "Tests complete."
  test-mps-dtype:
-    uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
-    with:
-      runner: macos-m1-stable # needs MPS, was macos-m1-stable
-      script: |
-        export PYTHON_VERSION="3.10"
-        set -x
-        # NS/MC: Remove previous installation of torch and torchao first
-        # as this script does not install anything into conda env but rather as system dep
-        pip3 uninstall -y torch || true
-        set -eou pipefail
-
-        pip3 uninstall -y torchao || true
-        set -eou pipefail
-
-        echo "::group::Print machine info"
-        uname -a
-        sysctl machdep.cpu.brand_string
-        sysctl machdep.cpu.core_count
-        echo "::endgroup::"
+    strategy:
+      matrix:
+        runner: [macos-m1-stable]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v2
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Run test
+        run: |
+          export PYTHON_VERSION="3.10"
+          set -x
+          # NS/MC: Remove previous installation of torch and torchao first
+          # as this script does not install anything into conda env but rather as system dep
+          pip3 uninstall -y torch || true
+          set -eou pipefail
 
-        echo "::group::Install requirements"
-        # Install requirements
-        ./install/install_requirements.sh
-        ls -la
-        pwd
-        pip3 list
-        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
-        echo "::endgroup::"
+          pip3 uninstall -y torchao || true
+          set -eou pipefail
 
-        echo "::group::Download checkpoints"
-        (
-          mkdir -p checkpoints/stories15M
-          pushd checkpoints/stories15M
-          curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
-          curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
-          popd
-        )
-        echo "::endgroup::"
+          echo "::group::Print machine info"
+          uname -a
+          sysctl machdep.cpu.brand_string
+          sysctl machdep.cpu.core_count
+          echo "::endgroup::"
 
-        echo "::group::Run inference"
-        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
-        export MODEL_NAME=stories15M
-        export MODEL_DIR=/tmp
-        for DTYPE in float16 float32; do
-          # if [ $(uname -s) == Darwin ]; then
-          # export DTYPE=float16
-          # fi
+          echo "::group::Install requirements"
+          # Install requirements
+          ./install/install_requirements.sh
+          ls -la
+          pwd
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+          echo "::endgroup::"
+
+          echo "::group::Download checkpoints"
+          (
+            mkdir -p checkpoints/stories15M
+            pushd checkpoints/stories15M
+            curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+            curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+            popd
+          )
+          echo "::endgroup::"
+
+          echo "::group::Run inference"
+          export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+          export MODEL_NAME=stories15M
+          export MODEL_DIR=/tmp
+          for DTYPE in float16 float32; do
+            # if [ $(uname -s) == Darwin ]; then
+            # export DTYPE=float16
+            # fi
 
-          python3 torchchat.py generate --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
+            python3 torchchat.py generate --dtype ${DTYPE} --device mps --checkpoint-path ${MODEL_PATH} --temperature 0
 
-          python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+            python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-          python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+            python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-          python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+            python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-          python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+            python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path ${MODEL_PATH} --temperature 0
 
-          PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
-        done
+            PYTORCH_ENABLE_MPS_FALLBACK=1 python3 torchchat.py generate --dtype ${DTYPE} --device mps --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0
+          done
   compile-gguf:
     strategy:
       matrix:
@@ -918,11 +950,11 @@ jobs:
       - name: Install ExecuTorch python
         run: |
           echo "Install ExecuTorch python"
-          pushd et-build/src/executorch
-          chmod +x ./install_requirements.sh
-          chmod +x ./install_requirements.py
-          ./install_requirements.sh
-          popd
+          export TORCHCHAT_ROOT=$PWD
+          export ET_BUILD_DIR="et-build"
+          ENABLE_ET_PYBIND="${1:-true}"
+          source "torchchat/utils/scripts/install_utils.sh"
+          install_executorch_python_libs $ENABLE_ET_PYBIND
       - name: Install runner
         run: |
           echo "Installing runner"
@@ -1023,3 +1055,92 @@ jobs:
       git submodule update --init
       ./runner/build_android.sh
       echo "Tests complete."
+
+  test-torchao-experimental:
+    strategy:
+      matrix:
+        runner: [macos-14-xlarge]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.10.11
+      - name: Setup Xcode
+        if: runner.os == 'macOS'
+        uses: maxim-lobanov/setup-xcode@v1
+        with:
+          xcode-version: '15.3'
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Install torchchat
+        run: |
+          echo "Installing pip3 packages"
+          ./install/install_requirements.sh
+          pip3 list
+          python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+      - name: Install torchao-ops
+        id: install-torchao-ops
+        run: |
+          bash torchchat/utils/scripts/build_torchao_ops.sh
+      - name: Set git shas
+        id: setup-hash
+        run: |
+          export TORCHCHAT_ROOT=${PWD}
+          echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
+      - name: Load or install ET
+        id: install-et
+        uses: actions/cache@v4
+        with:
+          path: |
+            ./et-build
+            ./torchchat/utils/scripts
+          key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }}
+      - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
+        continue-on-error: true
+        run: |
+          echo "Installing ExecuTorch"
+          bash torchchat/utils/scripts/install_et.sh
+      - name: Install ExecuTorch python
+        run: |
+          echo "Install ExecuTorch python"
+          export TORCHCHAT_ROOT=$PWD
+          export ET_BUILD_DIR="et-build"
+          ENABLE_ET_PYBIND="${1:-true}"
+          source "torchchat/utils/scripts/install_utils.sh"
+          install_executorch_python_libs $ENABLE_ET_PYBIND
+      - name: Install runner
+        run: |
+          echo "Installing runner"
+          bash torchchat/utils/scripts/build_native.sh et link_torchao_ops
+      - name: Install runner AOTI
+        id: install-runner-aoti
+        run: |
+          bash torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
+      - name: Run inference
+        run: |
+          python torchchat.py download stories110M
+          wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          export PRMT="Once upon a time in a land far away"
+          echo "Generate eager"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          echo "Generate compile"
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+          echo "Export and run ET (C++ runner)"
+          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
+          echo "Export and run AOTI (C++ runner)"
+          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
+          echo "Generate AOTI"
+          python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
+          echo "Tests complete."
diff --git a/.gitignore b/.gitignore
index 044bad856..74d0a28fa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ __pycache__/
 # Build directories
 build/android/*
 et-build/*
+torchao-build/*
 runner-et/cmake-out/*
 runner-aoti/cmake-out/*
 cmake-out/
diff --git a/docs/quantization.md b/docs/quantization.md
index 1f619e58e..c0899adee 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -118,6 +118,75 @@ python3 torchchat.py export llama3 --quantize '{"embedding": {"bitwidth": 4, "gr
 python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my name is"
 ```
+
+## Experimental TorchAO lowbit kernels
+
+### Use
+The quantization scheme a8wxdq dynamically quantizes activations to 8 bits and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+It takes arguments bitwidth (2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
+The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
+Roughly speaking, {bitwidth: 4, groupsize: 256, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
+
+You should expect high performance on ARM CPUs if bitwidth is 2, 3, 4, or 5 and groupsize is divisible by 16. On other platforms, or with other argument choices, a slow fallback kernel is used instead, and you will see warnings about this during quantization.
+
+### Setup
+To use a8wxdq, you must set up the torchao experimental kernels. They only work on devices with ARM CPUs, for example Mac computers with Apple Silicon.
+
+From the torchchat root directory, run
+```
+sh torchchat/utils/scripts/build_torchao_ops.sh
+```
+
+This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+
+Note: if you want to use the new kernels in the AOTI and ExecuTorch C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
+
+```
+sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
+```
+
+```
+sh torchchat/utils/scripts/build_native.sh et link_torchao_ops
+```
+
+Note: before running `sh torchchat/utils/scripts/build_native.sh et link_torchao_ops`, you must first install ExecuTorch with `sh torchchat/utils/scripts/install_et.sh` if you have not done so already.
+
+### Examples
+
+Below we show how to use the new kernels. Except for ExecuTorch, you can specify the number of threads used by setting OMP_NUM_THREADS (as with PyTorch in general). Doing so is optional; if you do not specify one, a default number of threads is chosen automatically.
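+
+For orientation, the general shape of an a8wxdq invocation is shown below before the detailed examples. The argument values here are illustrative assumptions (a hypothetical 3-bit configuration that satisfies the groupsize-divisible-by-16 guidance above), not tuned recommendations:
+
+```
+# Hypothetical 3-bit, groupsize-128 variant quantized with scales and zeros
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": true}}' --prompt "Once upon a time,"
+```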
+
+#### Eager mode
+```
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+```
+
+#### torch.compile
+```
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+```
+
+#### AOTI
+```
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
+```
+
+If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:
+
+```
+OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
+```
+
+#### ExecuTorch
+```
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+```
+
+Note: the exported *.pte file can only be run with torchchat's ExecuTorch C++ runner, built using the setup instructions above; it will not work with the `python torchchat.py generate` command.
+
+```
+./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
+```
+
 ## Quantization Profiles
 
 Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, `pi5.json`
diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt
new file mode 100644
index 000000000..b28bd09cd
--- /dev/null
+++ b/install/.pins/torchao-pin.txt
@@ -0,0 +1 @@
+63cb7a9857654784f726fec75c0dc36167094d8a
diff --git a/runner/aoti.cmake b/runner/aoti.cmake
index 156e9bcce..082a6f5ce 100644
--- a/runner/aoti.cmake
+++ b/runner/aoti.cmake
@@ -28,3 +28,7 @@ if(Torch_FOUND)
   target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m)
   set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17)
 endif()
+
+if (LINK_TORCHAO_OPS)
+  target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_ATEN${CMAKE_SHARED_LIBRARY_SUFFIX}")
+endif()
diff --git a/runner/et.cmake b/runner/et.cmake
index 99e67a025..c788ead56 100644
--- a/runner/et.cmake
+++ b/runner/et.cmake
@@ -116,6 +116,14 @@ if(executorch_FOUND)
     target_link_libraries(et_run PRIVATE log)
   endif()
 
+  if(LINK_TORCHAO_OPS)
+    target_link_libraries(et_run PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_EXECUTORCH.a>")
+    target_link_libraries(et_run PRIVATE
+      "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_kernels_aarch64.a"
+      "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_ops_linear_EXECUTORCH.a"
+    )
+  endif()
+
 else()
   MESSAGE(WARNING "ExecuTorch package not found")
 endif()
diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py
index a0d9248a9..77b03fcba 100644
--- a/torchchat/utils/quantize.py
+++ b/torchchat/utils/quantize.py
@@ -96,10 +96,19 @@ def quantize_model(
             precision = get_precision()
 
         try:
-            # Easier to ask forgiveness than permission
-            quant_handler = ao_quantizer_class_dict[quantizer](
-                groupsize=q_kwargs["groupsize"], device=device, precision=precision
-            )
+            if quantizer == "linear:a8wxdq":
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    device=device,
+                    precision=precision,
+                    bitwidth=q_kwargs.get("bitwidth", 4),
+                    groupsize=q_kwargs.get("groupsize", 128),
+                    has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
+                )
+            else:
+                # Easier to ask forgiveness than permission
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                )
         except TypeError as e:
             if "unexpected keyword argument 'device'" in str(e):
                 quant_handler = ao_quantizer_class_dict[quantizer](
@@ -861,3 +870,33 @@ def quantized_model(self) -> nn.Module:
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
+
+try:
+    import importlib.util
+    import sys
+    import os
+    torchao_build_path = f"{os.getcwd()}/torchao-build"
+
+    # Try loading quantizer
+    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_quant_api",
+        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
+    )
+    torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
+    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
+    torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
+    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
+    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+
+    # Try loading custom op
+    try:
+        import glob
+        libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/liblinear_a8wxdq_ATEN.*")
+        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+        torch.ops.load_library(libs[0])
+    except Exception as e:
+        print("Failed to load torchao ops library with error: ", e)
+        print("Slow fallback kernels will be used.")
+
+except Exception as e:
+    print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")
diff --git a/torchchat/utils/scripts/build_native.sh b/torchchat/utils/scripts/build_native.sh
index 924b86a65..3c2c1c846 100755
--- a/torchchat/utils/scripts/build_native.sh
+++ b/torchchat/utils/scripts/build_native.sh
@@ -26,6 +26,7 @@ if [ $# -eq 0 ]; then
   exit 1
 fi
 
+LINK_TORCHAO_OPS=OFF
 while (( "$#" )); do
   case "$1" in
     -h|--help)
@@ -42,6 +43,11 @@ while (( "$#" )); do
       TARGET="et"
       shift
       ;;
+    link_torchao_ops)
+      echo "Linking with torchao ops..."
+      LINK_TORCHAO_OPS=ON
+      shift
+      ;;
     *)
       echo "Invalid option: $1"
       show_help
@@ -66,14 +72,28 @@ if [[ "$TARGET" == "et" ]]; then
     echo "Make sure you run install_executorch_libs"
     exit 1
   fi
+
+  if [[ "$LINK_TORCHAO_OPS" == "ON" ]]; then
+    if [ ! -d "${TORCHCHAT_ROOT}/torchao-build" ]; then
+      echo "Directory ${TORCHCHAT_ROOT}/torchao-build does not exist."
+      echo "Make sure you run clone_torchao"
+      exit 1
+    fi
+
+    source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
+    find_cmake_prefix_path
+    EXECUTORCH_INCLUDE_DIRS="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/include;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src"
+    EXECUTORCH_LIBRARIES="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libexecutorch_no_prim_ops.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libextension_threadpool.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libcpuinfo.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libpthreadpool.a"
+    install_torchao_executorch_ops
+  fi
 fi
 popd
 
 # CMake commands
 if [[ "$TARGET" == "et" ]]; then
-  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
+  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
 else
-  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
+  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
 fi
 cmake --build ./cmake-out --target "${TARGET}"_run
diff --git a/torchchat/utils/scripts/build_torchao_ops.sh b/torchchat/utils/scripts/build_torchao_ops.sh
new file mode 100644
index 000000000..a8fd8bea2
--- /dev/null
+++ b/torchchat/utils/scripts/build_torchao_ops.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+
+source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
+
+pushd ${TORCHCHAT_ROOT}
+find_cmake_prefix_path
+clone_torchao
+install_torchao_aten_ops
+popd
diff --git a/torchchat/utils/scripts/install_et.sh b/torchchat/utils/scripts/install_et.sh
index 04db3b287..8062a8316 100755
--- a/torchchat/utils/scripts/install_et.sh
+++ b/torchchat/utils/scripts/install_et.sh
@@ -19,10 +19,4 @@ pushd ${TORCHCHAT_ROOT}
 find_cmake_prefix_path
 clone_executorch
 install_executorch_libs $ENABLE_ET_PYBIND
-install_executorch_python_libs $ENABLE_ET_PYBIND
-# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
-# has no attribute 'utils'' error from evaluate CI jobs and remove
-# `import lm_eval` from torchchat.py since it requires a specific version
-# of numpy.
-pip install numpy=='1.26.4'
 popd
diff --git a/torchchat/utils/scripts/install_utils.sh b/torchchat/utils/scripts/install_utils.sh
index 0ff4608c6..ec9677373 100644
--- a/torchchat/utils/scripts/install_utils.sh
+++ b/torchchat/utils/scripts/install_utils.sh
@@ -93,6 +93,13 @@ install_executorch_python_libs() {
     echo "Installing pybind"
     bash ./install_requirements.sh --pybind xnnpack
   fi
+
+  # TODO: figure out the root cause of 'AttributeError: module 'evaluate'
+  # has no attribute 'utils'' error from evaluate CI jobs and remove
+  # `import lm_eval` from torchchat.py since it requires a specific version
+  # of numpy.
+  pip install numpy=='1.26.4'
+  pip3 list
   popd
 }
 
@@ -161,3 +168,51 @@ install_executorch_libs() {
   install_executorch_cpp_libs
   install_executorch_python_libs $1
 }
+
+clone_torchao() {
+  echo "Cloning torchao to ${TORCHCHAT_ROOT}/torchao-build/src"
+  rm -rf ${TORCHCHAT_ROOT}/torchao-build
+  mkdir -p ${TORCHCHAT_ROOT}/torchao-build/src
+  pushd ${TORCHCHAT_ROOT}/torchao-build/src
+  pwd
+
+  git clone https://github.com/pytorch/ao.git
+  cd ao
+  git checkout $(cat ${TORCHCHAT_ROOT}/install/.pins/torchao-pin.txt)
+
+  popd
+}
+
+install_torchao_aten_ops() {
+  echo "Building torchao custom ops for ATen"
+  pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
+
+  CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
+  cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
+    -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
+    -DCMAKE_BUILD_TYPE="Release" \
+    -DTORCHAO_OP_TARGET="ATEN" \
+    -S . \
+    -B ${CMAKE_OUT_DIR} -G Ninja
+  cmake --build ${CMAKE_OUT_DIR} --target install --config Release
+
+  popd
+}
+
+install_torchao_executorch_ops() {
+  echo "Building torchao custom ops for ExecuTorch"
+  pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
+
+  CMAKE_OUT_DIR="${TORCHCHAT_ROOT}/torchao-build/cmake-out"
+  cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
+    -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
+    -DCMAKE_BUILD_TYPE="Release" \
+    -DTORCHAO_OP_TARGET="EXECUTORCH" \
+    -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
+    -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
+    -S . \
+    -B ${CMAKE_OUT_DIR} -G Ninja
+  cmake --build ${CMAKE_OUT_DIR} --target install --config Release
+
+  popd
+}
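
For reference, the pieces above compose end to end as follows. This is a minimal sketch, assuming an ARM-based Mac with the torchchat root as the working directory; the stories110M model and the a8wxdq arguments are the ones used by the CI job above:

```bash
# Build the torchao ATen custom ops (wraps clone_torchao + install_torchao_aten_ops).
bash torchchat/utils/scripts/build_torchao_ops.sh

# Download a small model and generate eagerly with the new a8wxdq scheme.
python3 torchchat.py download stories110M
python3 torchchat.py generate stories110M --device cpu --dtype float32 \
  --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' \
  --prompt "Once upon a time,"
```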