From 8451b0df3d90b9dfde41869ecc5caf30c6f419db Mon Sep 17 00:00:00 2001 From: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:54:33 -0700 Subject: [PATCH] Integrate flex attention (#1193) --- recipes/full_finetune_distributed.py | 20 +- recipes/full_finetune_single_device.py | 18 +- recipes/lora_finetune_distributed.py | 18 +- recipes/lora_finetune_single_device.py | 18 +- recipes/qat_distributed.py | 18 +- tests/torchtune/config/test_config_utils.py | 8 +- tests/torchtune/data/test_collate.py | 118 ++++++++ .../torchtune/datasets/test_packed_dataset.py | 83 ++---- .../torchtune/modules/test_attention_utils.py | 150 +++++++++++ tests/torchtune/utils/test_logging.py | 41 ++- torchtune/config/_utils.py | 12 +- torchtune/data/__init__.py | 2 + torchtune/data/_collate.py | 63 +++++ torchtune/data/_common.py | 4 + torchtune/datasets/_packed.py | 68 ++--- .../models/llama3/_component_builders.py | 3 +- torchtune/modules/__init__.py | 3 + torchtune/modules/attention.py | 31 ++- torchtune/modules/attention_utils.py | 252 ++++++++++++++++++ torchtune/modules/transformer.py | 54 ++-- torchtune/utils/logging.py | 34 +++ 21 files changed, 817 insertions(+), 201 deletions(-) create mode 100644 tests/torchtune/modules/test_attention_utils.py create mode 100644 torchtune/modules/attention_utils.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 14afe984aa..f3c55bcbbb 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -20,7 +20,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY @@ -227,7 +227,7 @@ def setup(self, cfg: DictConfig) -> None: self._loss_fn = config.instantiate(cfg.loss) if self._compile: - training.compile_loss(self.loss_fn, verbose=self._is_rank_zero) + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model @@ -491,14 +491,14 @@ def _setup_data( dataset=ds, batch_size=batch_size, sampler=sampler, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, ), ) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index fb668fa160..94f804ef85 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -18,7 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY @@ -451,14 +451,14 @@ def _setup_data( dataset=ds, batch_size=batch_size, sampler=sampler, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, ), ) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 8319ee3eb7..f635da8d52 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -20,7 +20,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( DoRALinear, @@ -559,14 +559,14 @@ def _setup_data( dataset=ds, batch_size=batch_size, sampler=sampler, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, ), ) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 6f11b57160..0862675a77 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, @@ -486,14 +486,14 @@ def _setup_data( dataset=ds, sampler=sampler, batch_size=batch_size, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, ), ) diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 8c34e5f961..7bbef2fcc8 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -21,7 +21,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY @@ -523,14 +523,14 @@ def _setup_data( dataset=ds, batch_size=batch_size, sampler=sampler, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, ), ) diff --git a/tests/torchtune/config/test_config_utils.py b/tests/torchtune/config/test_config_utils.py index 37d90c49e1..2924ddfe5b 100644 --- a/tests/torchtune/config/test_config_utils.py +++ b/tests/torchtune/config/test_config_utils.py @@ -131,13 +131,13 @@ def test_log_config(self, capsys): with mock.patch( "torchtune.config._utils.get_logger", return_value=logger ), mock.patch( - "torchtune.config._utils.dist.is_available", return_value=True + "torchtune.utils.logging.dist.is_available", return_value=True ), mock.patch( - "torchtune.config._utils.dist.is_initialized", return_value=True + "torchtune.utils.logging.dist.is_initialized", return_value=True ): # Make sure rank 0 logs as expected with mock.patch( - "torchtune.config._utils.dist.get_rank", + "torchtune.utils.logging.dist.get_rank", return_value=0, ): log_config("test", cfg) @@ -153,7 +153,7 @@ def test_log_config(self, capsys): # Make sure all other ranks do not log anything with mock.patch( - "torchtune.config._utils.dist.get_rank", + "torchtune.utils.logging.dist.get_rank", return_value=1, ): log_config("test", cfg) diff --git a/tests/torchtune/data/test_collate.py b/tests/torchtune/data/test_collate.py index a1dfbf5a4b..e697d27f47 100644 --- a/tests/torchtune/data/test_collate.py +++ b/tests/torchtune/data/test_collate.py @@ -6,14 +6,19 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +from unittest import mock + import pytest import torch +from tests.test_utils import gpu_test from torchtune.data import ( left_pad_sequence, padded_collate, padded_collate_dpo, + padded_collate_packed, padded_collate_sft, ) +from torchtune.modules.attention_utils import _SUPPORTS_FLEX_ATTENTION class TestPaddedCollateSFT: @@ -47,6 +52,119 @@ def test_batch_pad_sequence(self): padded_label, torch.tensor([10, ignore_idx, ignore_idx]) ) + @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False) + def test_padded_collate_packed_sdpa(self): + token_pairs = [ + { + "tokens": torch.tensor([1, 2, 3, 4, 5, 6]), + "labels": torch.tensor([7, 8, 9, 10, 11, 12]), + "input_pos": torch.tensor([0, 1, 2, 0, 1, 0]), + "seq_lens": torch.tensor([3, 2, 1]), + }, + { + "tokens": torch.tensor([13, 14, 15, 16, 17, 18]), + "labels": torch.tensor([19, 20, 21, 22, 23, 24]), + "input_pos": torch.tensor([0, 1, 0, 1, 0, 1]), + "seq_lens": torch.tensor([2, 2, 2]), + }, + ] + collated = padded_collate_packed( + batch=token_pairs, + ) + torch.testing.assert_close( + collated["tokens"], + torch.tensor([[1, 2, 3, 4, 5, 6], [13, 14, 15, 16, 17, 18]]), + ) + torch.testing.assert_close( + collated["labels"], + torch.tensor([[7, 8, 9, 10, 11, 12], [19, 20, 21, 22, 23, 24]]), + ) + torch.testing.assert_close( + collated["input_pos"], + torch.tensor([[0, 1, 2, 0, 1, 0], [0, 1, 0, 1, 0, 1]]), + ) + torch.testing.assert_close( + collated["mask"], + torch.tensor( + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 1], + ], + ], + dtype=torch.bool, + ), + ) + + @pytest.mark.skipif( + not _SUPPORTS_FLEX_ATTENTION, + reason="Please install a nightly build of torch to run this test.", + ) + @gpu_test(gpu_count=1) + def test_padded_collate_packed_flex(self): + # create_block_mask requires that seq_len be divisible by 128, the default block size. + # see https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L636 + batch = [ + { + "tokens": torch.arange(128, dtype=torch.long), + "labels": torch.arange(128, dtype=torch.long), + "input_pos": torch.arange(128, dtype=torch.long), + "seq_lens": torch.ones(64, dtype=torch.long) * 2, + }, + { + "tokens": torch.arange(128, 256, dtype=torch.long), + "labels": torch.arange(128, 256, dtype=torch.long), + "input_pos": torch.arange(128, 256, dtype=torch.long), + "seq_lens": torch.ones(32, dtype=torch.long) * 4, + }, + ] + collated = padded_collate_packed( + batch=batch, + ) + torch.testing.assert_close( + collated["tokens"], + torch.stack( + [ + torch.arange(128, dtype=torch.long), + torch.arange(128, 256, dtype=torch.long), + ] + ), + ) + torch.testing.assert_close( + collated["labels"], + torch.stack( + [ + torch.arange(128, dtype=torch.long), + torch.arange(128, 256, dtype=torch.long), + ] + ), + ) + torch.testing.assert_close( + collated["input_pos"], + torch.stack( + [ + torch.arange(128, dtype=torch.long), + torch.arange(128, 256, dtype=torch.long), + ] + ), + ) + torch.testing.assert_close( + collated["mask"].to_dense(), + torch.tensor([[[[1]]], [[[1]]]], dtype=torch.int32, device="cuda"), + ) + class TestLeftPadSequence: def test_left_pad_sequence(self): diff --git a/tests/torchtune/datasets/test_packed_dataset.py b/tests/torchtune/datasets/test_packed_dataset.py index 8afb532c0b..208ac333f5 100644 --- a/tests/torchtune/datasets/test_packed_dataset.py +++ b/tests/torchtune/datasets/test_packed_dataset.py @@ -48,34 +48,27 @@ def __len__(self): class TestPackedDataset: - def _get_expected_mask_and_input_pos( + def _get_expected_seq_lens_and_input_pos( self, max_seq_len, sample_size, split_across_pack ): """ - Generate expected integer mask and position ids for given max sequence + Generate expected seq lens and position ids for given max sequence length and sample length """ num_samples, remainder = divmod(max_seq_len, sample_size) + seq_lens = [sample_size] * num_samples if split_across_pack and remainder > 0: num_samples += 1 - mask = torch.block_diag( - *[ - torch.tril(torch.ones(sample_size, sample_size, dtype=torch.bool)) - for i in range(1, num_samples + 1) - ] - ) input_pos = [list(range(sample_size)) for i in range(1, num_samples + 1)] input_pos = list(itertools.chain(*input_pos)) - # Emulate mask and position id padding - if not split_across_pack and remainder > 0: - mask = torch.block_diag( - mask, - torch.eye(remainder, dtype=torch.bool), - ) - input_pos.extend(list(range(sample_size, sample_size + remainder))) + # Emulate seq len and position id padding + if remainder > 0: + if not split_across_pack: + input_pos.extend(list(range(sample_size, sample_size + remainder))) + seq_lens.extend([remainder]) - return mask[:max_seq_len, :max_seq_len], torch.tensor(input_pos[:max_seq_len]) + return torch.tensor(seq_lens), torch.tensor(input_pos[:max_seq_len]) def _calculate_num_packs( self, dataset_size, max_seq_len, sample_size, split_across_pack, max_packs @@ -122,7 +115,6 @@ def test_packed_dataset( assert ( len(packed[0]["tokens"]) == len(packed[0]["labels"]) - == len(packed[0]["mask"]) == len(packed[0]["input_pos"]) ) # Check that samples are packed correctly - very last individual sample @@ -145,10 +137,14 @@ def test_packed_dataset( assert packed[-1]["tokens"][-1].item() == last_index - expected_mask, expected_input_pos = self._get_expected_mask_and_input_pos( + ( + expected_seq_lens, + expected_input_pos, + ) = self._get_expected_seq_lens_and_input_pos( max_seq_len, sample_size, split_across_pack ) - torch.testing.assert_close(packed[0]["mask"], expected_mask) + + torch.testing.assert_close(packed[0]["seq_lens"], expected_seq_lens) torch.testing.assert_close(packed[0]["input_pos"], expected_input_pos) def test_packed_dataset_real_data(self): @@ -162,48 +158,15 @@ def test_packed_dataset_real_data(self): torch.tensor([5, 2, 6, 4, 3, 8, -1, 0, 4, 3]), torch.tensor([4, 3, 2, 5, 7, -1, -100, -100, -100, -100]), ] - expected_mask = [ + expected_seq_lens = [ torch.tensor( - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] + [7, 3], ), torch.tensor( - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - ] + [7, 3], ), torch.tensor( - [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ] + [6, 4], ), ] expected_input_pos = [ @@ -219,16 +182,16 @@ def test_packed_dataset_real_data(self): ) for i in range(len(packed)): - prompt, label, mask, input_pos = ( + prompt, label, seq_lens, input_pos = ( packed[i]["tokens"], packed[i]["labels"], - packed[i]["mask"], + packed[i]["seq_lens"], packed[i]["input_pos"], ) torch.testing.assert_close(prompt, expected_tokenized_prompts[i]) torch.testing.assert_close(label, expected_tokenized_labels[i]) torch.testing.assert_close(input_pos, expected_input_pos[i]) - torch.testing.assert_close(mask, expected_mask[i].to(dtype=torch.bool)) + torch.testing.assert_close(seq_lens, expected_seq_lens[i]) def test_pad_pack(self): padding_idx = -8 @@ -255,6 +218,7 @@ def test_pad_pack(self): padded_input = padded["tokens"] padded_label = padded["labels"] padded_input_pos = padded["input_pos"] + padded_seq_lens = padded["seq_lens"] torch.testing.assert_close( padded_input, torch.tensor([2, 5, padding_idx, padding_idx]) @@ -263,6 +227,7 @@ def test_pad_pack(self): padded_label, torch.tensor([3, 7, ignore_idx, ignore_idx]) ) torch.testing.assert_close(padded_input_pos, torch.tensor([8, 0, 1, 2])) + torch.testing.assert_close(padded_seq_lens, torch.tensor([1, 1, 2])) def test_pack_errors_if_sample_too_long(self): dataset = DummyDataset(8) diff --git a/tests/torchtune/modules/test_attention_utils.py b/tests/torchtune/modules/test_attention_utils.py new file mode 100644 index 0000000000..a646c37117 --- /dev/null +++ b/tests/torchtune/modules/test_attention_utils.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +from unittest import mock + +import pytest +import torch +from tests.test_utils import gpu_test + +from torchtune.modules.attention_utils import ( + _get_document_ids_from_seq_lens, + _sdpa_or_flex_attention, + _SUPPORTS_FLEX_ATTENTION, + create_block_causal_mask, + packed_block_causal_mask, +) + + +class TestBlockCausalMask: + @pytest.fixture + def seq_lens(self): + return [torch.tensor([2, 3, 1]), torch.tensor([2, 2, 2, 0])] + + def test_get_document_ids_from_seq_lens(self, seq_lens): + actual = _get_document_ids_from_seq_lens(seq_lens) + expected = torch.tensor([[0, 0, 1, 1, 1, 2], [0, 0, 1, 1, 2, 2]]) + torch.testing.assert_close(actual, expected) + + def test_create_block_causal_mask(self, seq_lens): + actual = create_block_causal_mask(seq_lens) + expected = torch.tensor( + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 1], + ], + ], + dtype=torch.bool, + ) + torch.testing.assert_close(actual, expected) + + @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False) + def test_packed_block_causal_mask_sdpa(self, seq_lens): + actual = packed_block_causal_mask(seq_lens) + expected = torch.tensor( + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 1], + ], + ], + dtype=torch.bool, + ) + torch.testing.assert_close(actual, expected) + + @pytest.mark.skipif( + not _SUPPORTS_FLEX_ATTENTION, + reason="Please install a nightly build of torch (>=2.5.0) to run this test.", + ) + @gpu_test(gpu_count=1) + def test_packed_block_causal_mask_flex(self): + # create_block_mask requires that seq_len be divisible by 128, the default block size. + # see https://github.com/pytorch/pytorch/blob/3bf6be457d40034aa4b603b7ea1b8977051221ed/torch/nn/attention/flex_attention.py#L792 # noqa + actual = packed_block_causal_mask( + [torch.tensor([64, 64]), torch.tensor([64, 64])] + ) + expected = torch.tensor([[[[1]]], [[[1]]]], device="cuda:0", dtype=torch.int32) + torch.testing.assert_close(actual.to_dense(), expected) + + +class TestSDPAOrFlexAttention: + @pytest.mark.skipif( + not _SUPPORTS_FLEX_ATTENTION, + reason="Please install a nightly build of torch (>=2.5.0) to run this test.", + ) + @mock.patch("torchtune.modules.attention_utils.torch.compile") + @mock.patch( + "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention" + ) + def test_flex_attention(self, mock_sdpa, mock_compile): + mock_flex = mock.MagicMock() + mock_compile.return_value = mock_flex + q = torch.ones(2, 3, 4) + k = torch.ones(2, 3, 4) + v = torch.ones(2, 3, 4) + attn_mask = torch.ones(2, 3, 4) + dropout_p = 0.0 + is_causal = False + + # Pretend that mask is actually a BlockMask + with mock.patch( + "torchtune.modules.attention_utils.isinstance", return_value=True + ): + _attention_call = _sdpa_or_flex_attention() + _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal) + mock_sdpa.assert_not_called() + mock_flex.assert_called_with(q, k, v, block_mask=attn_mask) + # If mask is not a BlockMask, then we should call SDPA + _attention_call = _sdpa_or_flex_attention() + _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal) + mock_sdpa.assert_called_once() + assert mock_flex.call_count == 1 + + @mock.patch("torchtune.modules.attention_utils._SUPPORTS_FLEX_ATTENTION", False) + @mock.patch("torchtune.modules.attention_utils.torch.compile") + @mock.patch( + "torchtune.modules.attention_utils.nn.functional.scaled_dot_product_attention" + ) + def test_sdpa_attention(self, mock_sdpa, mock_compile): + mock_flex = mock.MagicMock() + mock_compile.return_value = mock_flex + q = torch.ones(2, 3, 4) + k = torch.ones(2, 3, 4) + v = torch.ones(2, 3, 4) + attn_mask = torch.ones(2, 3, 4) + dropout_p = 0.0 + is_causal = False + _attention_call = _sdpa_or_flex_attention() + _ = _attention_call(q, k, v, attn_mask, dropout_p, is_causal) + mock_sdpa.assert_called_once() + mock_flex.assert_not_called() diff --git a/tests/torchtune/utils/test_logging.py b/tests/torchtune/utils/test_logging.py index fd5054d53d..d1dee6bcfe 100644 --- a/tests/torchtune/utils/test_logging.py +++ b/tests/torchtune/utils/test_logging.py @@ -4,8 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +from io import StringIO +from unittest import mock + import pytest -from torchtune.utils.logging import deprecated +from torchtune.utils.logging import deprecated, log_rank_zero def test_deprecated(): @@ -33,3 +37,38 @@ def dummy_func(): match="dummy_func is deprecated and will be removed in future versions. Please use `totally_awesome_func` instead.", ): dummy_func() + + +def test_log_rank_zero(capsys): + # Create a logger and add a StreamHandler to it so we can + # assert on logged strings + logger = logging.getLogger(__name__) + logger.setLevel("DEBUG") + stream = StringIO() + handler = logging.StreamHandler(stream) + logger.addHandler(handler) + + with mock.patch( + "torchtune.utils.logging.dist.is_available", return_value=True + ), mock.patch("torchtune.utils.logging.dist.is_initialized", return_value=True): + # Make sure rank 0 logs as expected + with mock.patch( + "torchtune.utils.logging.dist.get_rank", + return_value=0, + ): + log_rank_zero(logger, "this is a test", level=logging.DEBUG) + output = stream.getvalue().strip() + assert "this is a test" in output + + # Clear the stream + stream.truncate(0) + stream.seek(0) + + # Make sure all other ranks do not log anything + with mock.patch( + "torchtune.utils.logging.dist.get_rank", + return_value=1, + ): + log_rank_zero(logger, "this is a test", level=logging.DEBUG) + output = stream.getvalue().strip() + assert not output diff --git a/torchtune/config/_utils.py b/torchtune/config/_utils.py index 44ee0e7507..5c3806228c 100644 --- a/torchtune/config/_utils.py +++ b/torchtune/config/_utils.py @@ -10,10 +10,9 @@ from typing import Any, Dict, List, Union from omegaconf import DictConfig, OmegaConf -from torch import distributed as dist from torchtune.config._errors import InstantiationError -from torchtune.utils.logging import get_logger +from torchtune.utils.logging import get_logger, log_rank_zero def log_config(recipe_name: str, cfg: DictConfig) -> None: @@ -24,14 +23,11 @@ def log_config(recipe_name: str, cfg: DictConfig) -> None: recipe_name (str): name of the recipe to display cfg (DictConfig): parsed config object """ - # Log the config only on rank 0 - rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 - if rank != 0: - return - logger = get_logger("DEBUG") cfg_str = OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True) - logger.info(msg=f"Running {recipe_name} with resolved config:\n\n{cfg_str}") + log_rank_zero( + logger=logger, msg=f"Running {recipe_name} with resolved config:\n\n{cfg_str}" + ) def _has_component(node: Union[Dict[str, Any], DictConfig]) -> bool: diff --git a/torchtune/data/__init__.py b/torchtune/data/__init__.py index 35956c57bc..b9955d0caa 100644 --- a/torchtune/data/__init__.py +++ b/torchtune/data/__init__.py @@ -9,6 +9,7 @@ left_pad_sequence, padded_collate, padded_collate_dpo, + padded_collate_packed, padded_collate_sft, ) from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX @@ -58,5 +59,6 @@ "padded_collate_dpo", "left_pad_sequence", "padded_collate", + "padded_collate_packed", "load_image", ] diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 005e5b9755..dcd6672d24 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -9,6 +9,8 @@ import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets._packed import PACK_TYPE +from torchtune.modules.attention_utils import packed_block_causal_mask def left_pad_sequence( @@ -214,6 +216,67 @@ def padded_collate_sft( return {"tokens": input_ids.long(), "labels": labels.long()} +def padded_collate_packed( + batch: List[PACK_TYPE], +) -> Dict[str, torch.Tensor]: + """Collate packed sequences into a batch. Only convert the seq lens into + a block mask for use with attention. Tokens, labels, and input_pos are + already padded to the same length within :class:`~torchtune.datasets.PackedDataset`. + + Args: + batch (List[PACK_TYPE]): A list of pack dictionaries containing the following keys: + - tokens: input token ids + - labels: label token ids + - input_pos: relative position ids for each sequence in pack + - seq_lens: lengths of each sample within the pack + + Returns: + Dict[str, torch.Tensor]: Collated input, label, input_pos, mask tensors. + + Example: + >>> token_pairs = [ + >>> {"tokens": [1, 2, 3, 4, 5, 6], "labels": [7, 8, 9, 10, 11, 12], + >>> "input_pos": [0, 1, 2, 0, 1, 0], "seq_lens": [3, 2, 1]}, + >>> {"tokens": [13, 14, 15, 16, 17, 18], "labels": [19, 20, 21, 22, 23, 24], + >>> "input_pos": [0, 1, 0, 1, 0, 1], "seq_lens": [2, 2, 2]}, + >>> ] + >>> collated = padded_collate_packed( + >>> batch=token_pairs, + >>> device=device, + >>> ) + >>> collated["mask"] + >>> tensor([ + >>> [[1, 0, 0, 0, 0, 0], + >>> [1, 1, 0, 0, 0, 0], + >>> [1, 1, 1, 0, 0, 0], + >>> [0, 0, 0, 1, 0, 0], + >>> [0, 0, 0, 1, 1, 0], + >>> [0, 0, 0, 0, 0, 1]], + >>> [[1, 0, 0, 0, 0, 0], + >>> [1, 1, 0, 0, 0, 0], + >>> [0, 0, 1, 0, 0, 0], + >>> [0, 0, 1, 1, 0, 0], + >>> [0, 0, 0, 0, 1, 0], + >>> [0, 0, 0, 0, 1, 1]]) + """ + + tokens = torch.stack([x["tokens"] for x in batch]) + labels = torch.stack([x["labels"] for x in batch]) + input_pos = torch.stack([x["input_pos"] for x in batch]) + seq_lens = [x["seq_lens"] for x in batch] + + block_mask = packed_block_causal_mask( + seq_lens=seq_lens, + ) + + return { + "tokens": tokens, + "labels": labels, + "input_pos": input_pos, + "mask": block_mask, + } + + def padded_collate_dpo( batch: List[Dict[str, List[int]]], padding_idx: int = 0, diff --git a/torchtune/data/_common.py b/torchtune/data/_common.py index 3f8c4607d7..d6c768d70b 100644 --- a/torchtune/data/_common.py +++ b/torchtune/data/_common.py @@ -3,5 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, List, Union + +import torch CROSS_ENTROPY_IGNORE_IDX = -100 +PACK_TYPE = Dict[str, Union[torch.Tensor, List[int]]] diff --git a/torchtune/datasets/_packed.py b/torchtune/datasets/_packed.py index 5f293455d8..8cee72eed9 100644 --- a/torchtune/datasets/_packed.py +++ b/torchtune/datasets/_packed.py @@ -4,17 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import torch from torch.nn import functional as F from torch.utils.data import Dataset -from torchtune.data import CROSS_ENTROPY_IGNORE_IDX +from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX, PACK_TYPE from tqdm import tqdm -PACK_TYPE = Dict[str, Union[torch.Tensor, List[int]]] - class PackedDataset(Dataset): """ @@ -170,7 +168,8 @@ def _split_and_add_pack(self, current_pack: PACK_TYPE) -> PACK_TYPE: if self.split_across_pack: boundary = self.max_seq_len # The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.max_seq_len`` - seq_len_padding = [self.max_seq_len - sum(current_pack["seq_lens"][:-1])] + leftover_seq_len = self.max_seq_len - sum(current_pack["seq_lens"][:-1]) + seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else [] else: boundary = self.previous_sample_boundary # If we aren't splitting across packs, we leave out the last sample b/c @@ -209,22 +208,21 @@ def _add_pack(self, pack: PACK_TYPE) -> None: self.packs.append(pack) def _convert_to_tensors(self, pack: PACK_TYPE) -> PACK_TYPE: - """Converts a pack into tensors. Pack comes in as a dict of lists and is converted to tensors. - The only key that does not get converted is ``seq_lens``. - """ + """Converts a pack into tensors. Pack comes in as a dict of lists and is converted to tensors.""" return { - "tokens": torch.tensor(pack["tokens"]), - "labels": torch.tensor(pack["labels"]), - "input_pos": torch.tensor(pack["input_pos"]), - "seq_lens": pack["seq_lens"], + "tokens": torch.tensor(pack["tokens"], dtype=torch.long), + "labels": torch.tensor(pack["labels"], dtype=torch.long), + "input_pos": torch.tensor(pack["input_pos"], dtype=torch.long), + "seq_lens": torch.tensor(pack["seq_lens"], dtype=torch.long), } def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE: """Pads a pack to ``self.max_seq_len``.""" # Pad tokens + num_padding_tokens = self.max_seq_len - len(pack["tokens"]) padded_tokens = F.pad( pack["tokens"], - (0, self.max_seq_len - len(pack["tokens"])), + (0, num_padding_tokens), value=padding_idx, ) @@ -235,6 +233,13 @@ def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE: value=CROSS_ENTROPY_IGNORE_IDX, ) + # Add padding tokens as a last seq len to ensure sum is max_seq_len + padded_seq_lens = ( + torch.cat([pack["seq_lens"], torch.tensor([num_padding_tokens])]) + if num_padding_tokens > 0 + else pack["seq_lens"] + ) + # Pad input_pos continuing the sequence from last value # in input_pos # e.g. [0 1 2] -> [0 1 2 3 4 5] for self.max_seq_len = 6 @@ -250,44 +255,11 @@ def _pad_pack(self, pack: PACK_TYPE, padding_idx: int) -> PACK_TYPE: "tokens": padded_tokens, "labels": padded_labels, "input_pos": padded_input_pos, - "seq_lens": pack["seq_lens"], # seq_len is untouched + "seq_lens": padded_seq_lens, } def __len__(self) -> int: return len(self.packs) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - """Constructs the attention mask on-the-fly and returns whole sample.""" - current_pack = self.packs[idx] - - num_samples_in_pack = len(current_pack["seq_lens"]) - total_seq_len = 0 - - block_attn_masks = [] - - for i, seq_len in enumerate(current_pack["seq_lens"]): - total_seq_len += seq_len - - # Append lower triangular matrix for causal mask - block_attn_masks.append( - torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)) - ) - - # If we're at the last sample and the total seq len is less than the max seq len, - # we need to pad with identity matrix for the remainder - if i == num_samples_in_pack - 1 and total_seq_len < self.max_seq_len: - block_attn_masks.append( - torch.eye( - self.max_seq_len - total_seq_len, - self.max_seq_len - total_seq_len, - dtype=torch.bool, - ) - ) - - return { - "tokens": current_pack["tokens"], - "labels": current_pack["labels"], - "input_pos": current_pack["input_pos"], - # Assemble the mask into a block causal matrix - "mask": torch.block_diag(*block_attn_masks), - } + return self.packs[idx] diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 180d307765..49ea3ed764 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from functools import partial -from typing import List, Literal, Optional +from typing import List, Optional from torch import nn @@ -15,7 +15,6 @@ MultiHeadAttention, FeedForward, FrozenNF4Linear, - KVCache, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 1c508941f7..695a7c6f1f 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from .attention import MultiHeadAttention # noqa +from .attention_utils import create_block_causal_mask, packed_block_causal_mask from .common_utils import reparametrize_as_dtype_state_dict_post_hook from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -38,4 +39,6 @@ "TransformerSelfAttentionLayer", "TransformerCrossAttentionLayer", "reparametrize_as_dtype_state_dict_post_hook", + "create_block_causal_mask", + "packed_block_causal_mask", ] diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 7731635ffd..a3b6655d0b 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -9,6 +9,7 @@ import torch from torch import nn +from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache logger = logging.getLogger(__name__) @@ -136,6 +137,9 @@ def __init__( self.k_norm = k_norm self.pos_embeddings = pos_embeddings + # Use flex attention if supported and we are sample packing + self._attention_call = _sdpa_or_flex_attention() + def setup_cache(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. If called after kv_cache is already setup, this will be skipped. @@ -171,7 +175,7 @@ def forward( x: torch.Tensor, y: Optional[torch.Tensor] = None, *, - mask: Optional[torch.Tensor] = None, + mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -179,12 +183,16 @@ def forward( x (torch.Tensor): input tensor with shape [b x s_x x d] for the query y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input for k and v. For self attention, x=y. Optional only with kv_cache enabled. - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape [batch_size x seq_length x seq_length]. This is applied after - the query-key multiplication and before the softmax. A value of True in row i - and column j means token i attends to token j. A value of False means token i - does not attend to token j. If no mask is specified, a causal mask - is used by default. Default is None. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either a boolean tensor with shape [b x s x s] or a + :class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value + of True in row i and column j means token i attends to token j. A value of False means + token i does not attend to token j. If no mask is specified, a causal mask + is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed + for document masking in a packed sequence via `create_block_mask + `_, we use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention. + Default is None. input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. @@ -284,16 +292,11 @@ def forward( if self.kv_cache is not None: k, v = self.kv_cache.update(input_pos, k, v) - # shape: [b, 1, s, s] - if mask is not None: - mask = mask[:, None, :, :] - - # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ - output = nn.functional.scaled_dot_product_attention( + output = self._attention_call( q, k, v, - attn_mask=mask, + mask=mask, dropout_p=self.attn_dropout, is_causal=self.kv_cache is None and mask is None and self.is_causal, ) diff --git a/torchtune/modules/attention_utils.py b/torchtune/modules/attention_utils.py new file mode 100644 index 0000000000..fc01085eaa --- /dev/null +++ b/torchtune/modules/attention_utils.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Callable, List, Optional, Union + +import torch + +from torch import nn +from torchtune.utils._version import torch_version_ge +from torchtune.utils.logging import get_logger, log_once + +_log: logging.Logger = get_logger() + +# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above +_SUPPORTS_FLEX_ATTENTION = ( + torch_version_ge("2.5.0") + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (7, 5) +) + +if _SUPPORTS_FLEX_ATTENTION: + from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask as create_block_causal_mask_flex, + flex_attention, + ) + + flex_attention_compiled = torch.compile(flex_attention, dynamic=False) + + # We cannot do nested compile, but flex attention only has perf benefits + # when compiled. To insulate it from the compiler, we wrap it with + # compiler.disable so that it can be used regardless of whether the model + # is compiled or not, and flex attention always remains compiled. + @torch.compiler.disable(recursive=False) + def compile_friendly_flex_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_mask: BlockMask, + ) -> torch.Tensor: + return flex_attention_compiled(q, k, v, block_mask=block_mask) + + _MaskType = Union[torch.Tensor, BlockMask] +else: + _MaskType = torch.Tensor + + +def _get_document_ids_from_seq_lens( + seq_lens: List[torch.Tensor], +) -> torch.Tensor: + """ + Convert a batch tensor of seq lens into integer IDs denoting sample ownership. + For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2]. + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + Returns: + Tensor: Document IDs of shape (batch_size, max_seq_len). + """ + batch_size = len(seq_lens) + batch_document_ids = [] + for sample_idx in range(batch_size): + # We assume seq lens sum to max seq lens, so document_ids should be of + # shape (max_seq_len, ) + document_ids = torch.cat( + [ + torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device) + for i, seq_len in enumerate(seq_lens[sample_idx]) + ] + ) + batch_document_ids.append(document_ids) + batch_document_ids = torch.stack(batch_document_ids) + return batch_document_ids + + +def create_block_causal_mask(seq_lens: List[torch.Tensor]) -> torch.Tensor: + """ + Given a batch tensor of seq lens defining the lengths of samples in each pack, + Construct a 2D block causal mask for each pack in the batch. For example, if + a single sample's seq_lens is [3, 2, 1], the mask would be:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + + Returns: + Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len). + """ + batch_block_attn_masks = [] + batch_size = len(seq_lens) + for sample_idx in range(batch_size): + block_attn_masks = [ + torch.tril( + torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device) + ) + for i, seq_len in enumerate(seq_lens[sample_idx]) + ] + + batch_block_attn_masks.append(torch.block_diag(*block_attn_masks)) + return torch.stack(batch_block_attn_masks) + + +def packed_block_causal_mask( + seq_lens: List[torch.Tensor], +) -> _MaskType: + """ + Create a block causal document mask for a batch of packed sequences. If on + torch version >= 2.5.0, this is done by creating a mask_mod function with the + block causal logic and passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full block causal + mask. If on an older version, a standard 2D block causal mask is created and returned. + + Args: + seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch, + shape (batch_size, n), where n is the max number of sequences in a pack and can vary + across packs. + + Returns: + _MaskType: BlockMask or Tensor if torch version < 2.5.0. + """ + if _SUPPORTS_FLEX_ATTENTION: + document_ids = _get_document_ids_from_seq_lens(seq_lens) + batch_size, max_seq_len = document_ids.shape + document_ids = document_ids.to("cuda") + + # Instead of passing a tensor mask, flex attention requires a mask_mod function + # that determines which elements of QK^T should be included in the attention + # computation prior to the softmax. For sample packing, we need both the + # logic for both causal mask and document mask. See PyTorch's official + # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods + def mask_mod(b, h, q_idx, kv_idx): + """ + Defines the logic of a block causal mask by combining both a standard causal mask + and a block diagonal document mask. + + See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` + for an illustration. + """ + causal_mask = q_idx >= kv_idx + document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx] + return causal_mask & document_mask + + return create_block_causal_mask_flex( + mask_mod, + batch_size, + None, + max_seq_len, + max_seq_len, + device="cuda", + ) + else: + return create_block_causal_mask(seq_lens=seq_lens) + + +def _sdpa_or_flex_attention() -> Callable: + """ + Helper function to decide when to call flex attention or SDPA. It will use + flex attention if ALL of the following conditions are met, otherwise it will + default to SDPA: + - torch version >= 2.5.0 + - we are sample packing, therefore mask is a BlockMask + - torch.cuda.get_device_capability() >= (7, 5) + """ + + if _SUPPORTS_FLEX_ATTENTION: + + def _attention_call( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[_MaskType], + dropout_p: float, + is_causal: bool, + ) -> torch.Tensor: + + # Flex attention uses the BlockMask + # (https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L168) + # instead of a traditional boolean tensor mask. If this is passed in, + # we assume the user wants to use flex attention instead of traditional SDPA. + # This will use flash attention under the hood with support for custom masks. + # Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset) + if isinstance(mask, BlockMask): + log_once( + _log, + "Using flex attention for attention computation since a BlockMask was passed in.", + level=logging.DEBUG, + ) + return compile_friendly_flex_attention( + q, + k, + v, + block_mask=mask, + ) + # If mask is a standard boolean tensor or None, then use SDPA + else: + # shape: [b, 1, s, s] + if mask is not None: + mask = mask[:, None, :, :] + + # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ + return nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + + else: + + def _attention_call( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[_MaskType], + dropout_p: float, + is_causal: bool, + ) -> torch.Tensor: + # shape: [b, 1, s, s] + if mask is not None: + mask = mask[:, None, :, :] + + # Flash attention from https://pytorch.org/blog/accelerating-large-language-models/ + return nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + + return _attention_call diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 9f3f01a585..65e511e5ea 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -11,6 +11,8 @@ from torch import nn from torchtune.modules import MultiHeadAttention +from torchtune.modules.attention_utils import _MaskType + class TransformerSelfAttentionLayer(nn.Module): """Transformer layer derived from the Llama2 model. Normalization is applied before the attention **and** FF layer. @@ -64,7 +66,7 @@ def forward( self, x: torch.Tensor, *, - mask: Optional[torch.Tensor] = None, + mask: Optional[_MaskType] = None, input_pos: Optional[torch.Tensor] = None, **kwargs: Dict, ) -> torch.Tensor: @@ -72,12 +74,16 @@ def forward( Args: x (torch.Tensor): input tensor with shape [batch_size x seq_length x embed_dim] - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape [batch_size x seq_length x seq_length]. This is applied after - the query-key multiplication and before the softmax. A value of True in row i - and column j means token i attends to token j. A value of False means token i - does not attend to token j. If no mask is specified, a causal mask - is used by default. Default is None. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either a boolean tensor with shape [b x s x s] or a + :class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value + of True in row i and column j means token i attends to token j. A value of False means + token i does not attend to token j. If no mask is specified, a causal mask + is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed + for document masking in a packed sequence via `create_block_mask + `_, we use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention. + Default is None. input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids of each token. During training, this is used to indicate the positions of each token relative to its sample when packed, shape [b x s]. @@ -406,7 +412,7 @@ def forward( self, tokens: torch.Tensor, *, - mask: Optional[torch.Tensor] = None, + mask: Optional[_MaskType] = None, encoder_input: Optional[torch.Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, @@ -414,11 +420,16 @@ def forward( """ Args: tokens (torch.Tensor): input tensor with shape [b x s] - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape [b x s x s]. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either a boolean tensor with shape [b x s x s] or a + :class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value + of True in row i and column j means token i attends to token j. A value of False means + token i does not attend to token j. If no mask is specified, a causal mask + is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed + for document masking in a packed sequence via `create_block_mask + `_, we use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention. + Default is None. encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend @@ -637,7 +648,7 @@ def forward( self, tokens: torch.Tensor, *, - mask: Optional[torch.Tensor] = None, + mask: Optional[_MaskType] = None, encoder_input: Optional[torch.Tensor] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, @@ -645,11 +656,16 @@ def forward( """ Args: tokens (torch.Tensor): input tensor with shape [b x s] - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape [b x s x s]. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either a boolean tensor with shape [b x s x s] or a + :class:`~torch.nn.attention.flex_attention.BlockMask`. If a boolean tensor, a value + of True in row i and column j means token i attends to token j. A value of False means + token i does not attend to token j. If no mask is specified, a causal mask + is used by default. If a :class:`~torch.nn.attention.flex_attention.BlockMask` is passed + for document masking in a packed sequence via `create_block_mask + `_, we use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention. + Default is None. encoder_input (Optional[torch.Tensor]): Optional input embeds from the encoder. Shape [b x s_e x d_e] encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend diff --git a/torchtune/utils/logging.py b/torchtune/utils/logging.py index c587042ad4..8272bda94f 100644 --- a/torchtune/utils/logging.py +++ b/torchtune/utils/logging.py @@ -9,6 +9,8 @@ from functools import lru_cache, wraps from typing import Callable, Optional, TypeVar +from torch import distributed as dist + T = TypeVar("T", bound=type) @@ -31,6 +33,22 @@ def get_logger(level: Optional[str] = None) -> logging.Logger: return logger +@lru_cache(None) +def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: + """ + Logs a message only once. LRU cache is used to ensure a specific message is + logged only once, similar to how :func:`~warnings.warn` works when the ``once`` + rule is set via command-line or environment variable. + + Args: + logger (logging.Logger): The logger. + msg (str): The warning message. + level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. + Defaults to ``logging.INFO``. + """ + log_rank_zero(logger=logger, msg=msg, level=level) + + def deprecated(msg: str = "") -> Callable[[T], T]: """ Decorator to mark an object as deprecated and print additional message. @@ -60,3 +78,19 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None: + """ + Logs a message only on rank zero. + + Args: + logger (logging.Logger): The logger. + msg (str): The warning message. + level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values. + Defaults to ``logging.INFO``. + """ + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + if rank != 0: + return + logger.log(level, msg)