Skip to content

Commit

Permalink
Update on "add eval for attention sink"
Browse files Browse the repository at this point in the history
This PR adds the function to evaluate the model's perplexity when AttentionSink is enabled.

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py which is used by the AttentionSink paper to evaluate the model's perplexity when AttentionSink is enabled.

Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/)

Perplexity measured for llama 3.2 1B and 1B_Instruct model up to 40k tokens with AttentionSink enabled:

<img width="966" alt="Screenshot 2024-11-25 at 2 46 04 PM" src="https://github.com/user-attachments/assets/ba7118f9-b5d7-4de8-b1fa-7d2ba0646515">


[ghstack-poisoned]
  • Loading branch information
helunwencser committed Dec 2, 2024
2 parents a3b8d91 + 0574fe0 commit 38d9e1c
Show file tree
Hide file tree
Showing 96 changed files with 5,202 additions and 714 deletions.
11 changes: 7 additions & 4 deletions .ci/scripts/test_llama.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ UPLOAD_DIR="${UPLOAD_DIR:-}"
# Default PT2E_QUANTIZE to empty string if not set
PT2E_QUANTIZE="${PT2E_QUANTIZE:-}"

# Default CMake Build Type to release mode
CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}

if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args
echo "Expecting atleast 4 positional arguments"
echo "Usage: [...]"
Expand Down Expand Up @@ -143,7 +146,7 @@ cmake_install_executorch_libraries() {
rm -rf cmake-out
retry cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
Expand All @@ -157,22 +160,22 @@ cmake_install_executorch_libraries() {
-DQNN_SDK_ROOT="$QNN_SDK_ROOT" \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-Bcmake-out .
cmake --build cmake-out -j9 --target install --config Debug
cmake --build cmake-out -j9 --target install --config "$CMAKE_BUILD_TYPE"
}

cmake_build_llama_runner() {
echo "Building llama runner"
dir="examples/models/llama"
retry cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Debug \
-DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM="$CUSTOM" \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-Bcmake-out/${dir} \
${dir}
cmake --build cmake-out/${dir} -j9 --config Debug
cmake --build cmake-out/${dir} -j9 --config "$CMAKE_BUILD_TYPE"

}

Expand Down
16 changes: 8 additions & 8 deletions .ci/scripts/test_llava.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
set -exu
# shellcheck source=/dev/null

BUILD_TYPE=${1:-Debug}
TARGET_OS=${2:-Native}
BUILD_DIR=${3:-cmake-out}
CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release}

echo "Building with BUILD_TYPE: $BUILD_TYPE, TARGET_OS: $TARGET_OS, BUILD_DIR: $BUILD_DIR"
echo "Building with CMAKE_BUILD_TYPE: $CMAKE_BUILD_TYPE, TARGET_OS: $TARGET_OS, BUILD_DIR: $BUILD_DIR"

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
PYTHON_EXECUTABLE=python3
Expand All @@ -32,7 +32,7 @@ if hash nproc &> /dev/null; then NPROC=$(nproc); fi

EXECUTORCH_COMMON_CMAKE_ARGS=" \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
Expand All @@ -49,7 +49,7 @@ cmake_install_executorch_libraries() {
${EXECUTORCH_COMMON_CMAKE_ARGS} \
-B${BUILD_DIR} .

cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${BUILD_TYPE}
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
}

cmake_install_executorch_libraries_for_android() {
Expand All @@ -59,14 +59,14 @@ cmake_install_executorch_libraries_for_android() {
${EXECUTORCH_COMMON_CMAKE_ARGS} \
-B${BUILD_DIR} .

cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${BUILD_TYPE}
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
}


LLAVA_COMMON_CMAKE_ARGS=" \
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
-DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-DEXECUTORCH_BUILD_XNNPACK=ON"
Expand All @@ -81,7 +81,7 @@ cmake_build_llava_runner() {
-B${BUILD_DIR}/${dir} \
${dir}

cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${BUILD_TYPE}
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${CMAKE_BUILD_TYPE}
}


Expand All @@ -98,7 +98,7 @@ cmake_build_llava_runner_for_android() {
-B${BUILD_DIR}/${dir} \
${dir}

cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${BUILD_TYPE}
cmake --build ${BUILD_DIR}/${dir} -j${NPROC} --config ${CMAKE_BUILD_TYPE}
}

# only export the one without custom op for now since it's
Expand Down
16 changes: 1 addition & 15 deletions .github/workflows/ghstack_land.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,7 @@ on:
pull_request:
types: [closed]
branches:
- 'gh/cccclai/[0-9]+/base'
- 'gh/dbort/[0-9]+/base'
- 'gh/dvorjackz/[0-9]+/base'
- 'gh/guangy10/[0-9]+/base'
- 'gh/helunwencser/[0-9]+/base'
- 'gh/jorgep31415/[0-9]+/base'
- 'gh/kimishpatel/[0-9]+/base'
- 'gh/kirklandsign/[0-9]+/base'
- 'gh/larryliu0820/[0-9]+/base'
- 'gh/lucylq/[0-9]+/base'
- 'gh/manuelcandales/[0-9]+/base'
- 'gh/mcr229/[0-9]+/base'
- 'gh/swolchok/[0-9]+/base'
- 'gh/SS-JIA/[0-9]+/base'
- 'gh/trivedivivek/[0-9]+/base'
- 'gh/*/[0-9]+/base'

jobs:
ghstack_merge_to_main:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/trunk.yml
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ jobs:
# ${CONDA_RUN} python -m unittest examples.models.llava.test.test_llava

# # run e2e (export, tokenizer and runner)
# PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh Release
# PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh

test-qnn-model:
name: test-qnn-model
Expand Down
56 changes: 16 additions & 40 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,22 @@ if(EXECUTORCH_BUILD_PTHREADPOOL
endif()

if(EXECUTORCH_BUILD_PYBIND)
# Setup RPATH.
# See https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling
if(APPLE)
set(CMAKE_MACOSX_RPATH ON)
set(_rpath_portable_origin "@loader_path")
else()
set(_rpath_portable_origin $ORIGIN)
endif(APPLE)
# Use separate rpaths during build and install phases
set(CMAKE_SKIP_BUILD_RPATH FALSE)
# Don't use the install-rpath during the build phase
set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE)
set(CMAKE_INSTALL_RPATH "${_rpath_portable_origin}")
# Automatically add all linked folders that are NOT in the build directory to
# the rpath (per library?)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11)

if(NOT EXECUTORCH_BUILD_EXTENSION_DATA_LOADER)
Expand Down Expand Up @@ -765,46 +781,6 @@ if(EXECUTORCH_BUILD_PYBIND)
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
target_link_libraries(portable_lib PRIVATE ${_dep_libs})
if(APPLE)
# pip wheels will need to be able to find the torch libraries. On Linux, the
# .so has non-absolute dependencies on libs like "libtorch.so" without
# paths; as long as we `import torch` first, those dependencies will work.
# But Apple dylibs do not support non-absolute dependencies, so we need to
# tell the loader where to look for its libraries. The LC_LOAD_DYLIB entries
# for the torch libraries will look like "@rpath/libtorch.dylib", so we can
# add an LC_RPATH entry to look in a directory relative to the installed
# location of our _portable_lib.so file. To see these LC_* values, run
# `otool -l _portable_lib*.so`.
set_target_properties(
portable_lib
PROPERTIES # Assume that this library will be installed in
# `site-packages/executorch/extension/pybindings`, and that
# the torch libs are in `site-packages/torch/lib`.
BUILD_RPATH "@loader_path/../../../torch/lib"
INSTALL_RPATH "@loader_path/../../../torch/lib"
# Assume <executorch> is the root `site-packages/executorch`
# Need to add <executorch>/extension/llm/custom_ops for
# libcustom_ops_aot_lib.dylib
BUILD_RPATH "@loader_path/../../extension/llm/custom_ops"
INSTALL_RPATH "@loader_path/../../extension/llm/custom_ops"
# Need to add <executorch>/kernels/quantized for
# libquantized_ops_aot_lib.dylib
BUILD_RPATH "@loader_path/../../kernels/quantized"
INSTALL_RPATH "@loader_path/../../kernels/quantized"
)
else()
set_target_properties(
portable_lib
PROPERTIES
# Assume <executorch> is the root `site-packages/executorch`
# Need to add <executorch>/extension/llm/custom_ops for
# libcustom_ops_aot_lib
# Need to add <executorch>/kernels/quantized for
# libquantized_ops_aot_lib
BUILD_RPATH
"$ORIGIN:$ORIGIN/../../extension/llm/custom_ops:$ORIGIN/../../kernels/quantized"
)
endif()

install(TARGETS portable_lib
LIBRARY DESTINATION executorch/extension/pybindings
Expand Down
11 changes: 11 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,14 @@ python_library(
"//executorch/backends/arm/operators:node_visitor",
],
)

python_library(
name = "arm_model_evaluator",
src = [
"util/arm_model_evaluator.py",
],
typing = True,
deps = [
"//caffe2:torch",
]
)
6 changes: 3 additions & 3 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
InsertSqueezeAfterSumPass,
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
Expand Down Expand Up @@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(InsertSqueezeAfterSumPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
Expand Down
58 changes: 58 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-unsafe

from inspect import isclass
from typing import Optional

import torch
Expand Down Expand Up @@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
fake_tensor, FakeTensor
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
return fake_tensor


def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
"""
Help-function for getting a value from node.args/ kwargs, three cases:
1. By position in node.args - Returns arg at given position or default_value if index is one out of bounds
2. By key in node.kwargs - Returns kwarg with given key or default_value if it deos not exist
3. By type in node.args - Returns first arg of args of given type. Useful for cases where arg postions may differ but types are unique.
"""
if isinstance(key, int):
if 0 <= key < len(args):
return args[key]
elif key == len(args):
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No defult value given for index {key}")
else:
raise RuntimeError(
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value)
elif isclass(key):
for arg in args:
if isinstance(arg, key):
return arg
if default_value is not None:
return default_value
else:
raise RuntimeError(f"No arg of type {key}")
else:
raise RuntimeError("Invalid type")


def set_node_arg(node: torch.fx.Node, i: int | str, value):
"""
Help-function for setting a value in node.args/ kwargs. If the index is one larger than the list size, the value is instead appended to the list.
"""
if isinstance(i, int):
if 0 <= i < len(node.args):
args = list(node.args)
args[i] = value
node.args = tuple(args)
return
elif i == len(node.args):
node.args = node.args + (value,)
else:
raise RuntimeError(
f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
)
elif isinstance(i, str):
kwargs = dict(node.kwargs)
kwargs[i] = value
node.kwargs = kwargs
else:
raise RuntimeError("Invalid type")
13 changes: 7 additions & 6 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
return super().call_operator(op, args, kwargs, meta)

x = args[0]
dim = args[1]
keepdim = args[2] if len(args) > 2 else False
if not keepdim:
return super().call_operator(op, args, kwargs, meta)
# if keepdim == True and dim == [-1, -2], mean.dim can be
x = get_node_arg(args, 0)
dim = get_node_arg(args, 1)
keepdim = get_node_arg(args, 2, False)

# if dim == [-1, -2], mean.dim can be
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
if dim == [-1, -2]:
# Simply return the mean.dim operator for future decomposition.
return super().call_operator(op, args, kwargs, meta)

shape = meta["val"].size()
dtype = meta["val"].dtype
input_shape = x.data.size()
Expand Down
27 changes: 16 additions & 11 deletions backends/arm/_passes/decompose_var_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


import torch
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

Expand Down Expand Up @@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
torch.ops.aten.var.dim,
):
return super().call_operator(op, args, kwargs, meta)
shape = meta["val"].size()

x = args[0]
input_shape = x.data.size()
shape = list(meta["val"].size())
if shape == []:
shape = [1 for _ in input_shape]

dtype = meta["val"].dtype
dim = args[1] if len(args) > 1 else list(range(len(shape)))
# Get dim from args based on argument type
dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))

if op == torch.ops.aten.var.dim:
correction = args[-2]
keepdim = args[-1]
keepdim = get_node_arg(args, bool, False)
correction = get_node_arg(args, int, 1)
else:
correction = kwargs["correction"]
keepdim = kwargs.get("keepdim", False)
if not keepdim:
return super().call_operator(op, args, kwargs, meta)
correction = get_node_arg(kwargs, "correction", 1)
keepdim = get_node_arg(kwargs, "keepdim", False)

x = args[0]
input_shape = x.data.size()
N = 1
for d in dim:
N *= input_shape[d]

mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
diff = super().call_operator(diff_op, (x, mean), {}, meta)
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)
Expand Down
Loading

0 comments on commit 38d9e1c

Please sign in to comment.