From 8e18e7fc6875e28acfa43a0a3791ea2c40145af6 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Thu, 23 Jan 2025 18:52:16 -0800
Subject: [PATCH 1/4] Update generate.py

Push backend manager into caller
---
 torchchat/generate.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index 7f37386ac..d8ba560a2 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -532,7 +532,6 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
-        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
@@ -541,7 +540,8 @@
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([attention_backend]):
+            # with torch.nn.attention.sdpa_kernel([attention_backend]):
+            if True: # preserve indentation while testing
 
                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
@@ -685,7 +685,6 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
-        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -802,7 +801,6 @@ def generate(
                 if self.is_llama3_model
                 else None
             ),
-            attention_backend=attention_backend,
            **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
@@ -1174,7 +1172,7 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
-            with prof:
+            with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend]), prof:
                 generator_func = self.generate(
                     self.model,
                     encoded,
@@ -1190,7 +1188,6 @@ def callback(x, *, done_generating=False):
                     start_pos=start_pos,
                     skip_cache_setup=not is_first_sample,
                     max_seq_length=max_seq_length,
-                    attention_backend=self.builder_args.attention_backend,
                 )
                 if generator_args.chat_mode:
                     start_pos += encoded.size(0)
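
Note on the pattern above: torch.nn.attention.sdpa_kernel is a context manager that
restricts which scaled-dot-product-attention backends may be used by every SDPA call
issued inside the block. That is why the selection can move out of decode_n_tokens:
the caller enters the context once per generation instead of re-entering it for every
decoded token. A minimal, self-contained sketch of that caller-side usage (illustrative
only, not torchchat code):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Toy query/key/value tensors: (batch, heads, seq_len, head_dim)
    q = torch.randn(1, 4, 16, 64)
    k = torch.randn(1, 4, 16, 64)
    v = torch.randn(1, 4, 16, 64)

    # The caller picks the backend once; every SDPA call inside the block uses it.
    with sdpa_kernel([SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)
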
From 3c3b367aa35fab06487dd213511badfcd151ddde Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Thu, 23 Jan 2025 19:04:24 -0800
Subject: [PATCH 2/4] Update more-tests.yml

Add tests for backends
---
 .github/workflows/more-tests.yml | 58 ++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/.github/workflows/more-tests.yml b/.github/workflows/more-tests.yml
index f772382d1..b28b150f9 100644
--- a/.github/workflows/more-tests.yml
+++ b/.github/workflows/more-tests.yml
@@ -83,3 +83,61 @@ jobs:
         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
+
+  test-sdpa-backends:
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.4"
+      timeout: 60
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        # Install requirements
+        ./install/install_requirements.sh cuda
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_NAME=stories15M
+        export MODEL_DIR=/tmp
+
+        for DEVICE in cpu cuda; do
+          for DTYPE in bfloat16 float16 float32; do
+            for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              ###################################################################
+              # Python execution interpreted vanilla
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0
+              ###################################################################
+              # prefill, and compile and prefill compile
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
+              ###################################################################
+              # sequential prefill
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill
+              ###################################################################
+              # prefill, and compile
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
+            done
+          done
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"

From e3933b2e0257614fa0cd57a6294f6cdd17ceb409 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Fri, 24 Jan 2025 12:31:20 -0800
Subject: [PATCH 3/4] Update more-tests.yml

print out parameters during execution
---
 .github/workflows/more-tests.yml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.github/workflows/more-tests.yml b/.github/workflows/more-tests.yml
index b28b150f9..c2502e7e4 100644
--- a/.github/workflows/more-tests.yml
+++ b/.github/workflows/more-tests.yml
@@ -19,6 +19,7 @@ jobs:
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
         echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -95,6 +96,7 @@
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
        echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -122,6 +124,8 @@
         for DEVICE in cpu cuda; do
           for DTYPE in bfloat16 float16 float32; do
             for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              echo "******************************************************************"
+              echo "******* $DEVICE $DTYPE $SDPA "
               ###################################################################
               # Python execution interpreted vanilla
               python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0
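
Note on the new CI job: it sweeps 2 devices x 3 dtypes x 4 SDPA backends (24
combinations) and runs each one in four modes (plain eager, compile with compile-prefill,
sequential prefill, and sequential prefill with compile). The --attention-backend string
presumably surfaces as self.builder_args.attention_backend in generate.py; a hypothetical
sketch of how such a flag value could be resolved to an SDPBackend member (torchchat's
actual CLI/builder code may do this differently):

    from torch.nn.attention import SDPBackend

    def resolve_attention_backend(name: str) -> SDPBackend:
        # Hypothetical helper: "flash_attention" -> SDPBackend.FLASH_ATTENTION,
        # "math" -> SDPBackend.MATH, etc.
        backend = getattr(SDPBackend, name.upper(), None)
        if backend is None:
            raise ValueError(f"unknown attention backend: {name}")
        return backend
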
From ed8ab550df7cf3e2ce87101660db6fcc5283d436 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Fri, 24 Jan 2025 15:30:26 -0800
Subject: [PATCH 4/4] Update generate.py

Allow math as fallback
---
 torchchat/generate.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchchat/generate.py b/torchchat/generate.py
index d8ba560a2..6143df206 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -1172,7 +1172,8 @@ def callback(x, *, done_generating=False):
                 prof = torch.profiler.profile()
             t0 = time.perf_counter()
             num_tokens_generated = 0
-            with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend]), prof:
+            # always allow math as fallback
+            with torch.nn.attention.sdpa_kernel([self.builder_args.attention_backend, torch.nn.attention.SDPBackend.MATH]), prof:
                 generator_func = self.generate(
                     self.model,
                     encoded,
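
With this final patch, the sdpa_kernel context always enables SDPBackend.MATH alongside
the user-selected backend, so SDPA can fall back to the math kernel when the requested
backend cannot serve a given device/dtype combination (for example, cudnn_attention on
CPU) rather than failing. A minimal sketch of that fallback behaviour (illustrative only,
not torchchat code):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # CPU tensors: cuDNN attention has no kernel for these inputs.
    q = torch.randn(1, 4, 16, 64)
    k = torch.randn(1, 4, 16, 64)
    v = torch.randn(1, 4, 16, 64)

    # Enabling MATH alongside the preferred backend lets SDPA fall back to the math
    # kernel when the preferred backend cannot handle this device/dtype; with only
    # CUDNN_ATTENTION enabled, this call would raise a "no available kernel" error.
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)
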