From 99c89da4f36b14c536d654eae7b16987024b5f07 Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Mon, 30 Dec 2024 06:56:58 +0200 Subject: [PATCH 1/2] Fix `_prepare_cos_sin` flow in RoPE for LoRA + long-contexts --- tests/lora/conftest.py | 1 + vllm/lora/punica_wrapper/utils.py | 24 +++++++++++++++---- .../model_executor/layers/rotary_embedding.py | 2 +- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 9f8de7cb74cb9..7f0dccc38ca52 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -64,6 +64,7 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): + import habana_frameworks.torch.hpu # noqa: F401 temp_file = tempfile.mkstemp()[1] backend_type = "hccl" if current_platform.is_hpu() else "nccl" init_distributed_environment( diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 7360c8c09e3ac..4504e19b20816 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -2,6 +2,8 @@ import torch +from vllm.platforms import current_platform + if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping @@ -86,10 +88,14 @@ def convert_mapping( embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) + if current_platform.is_hpu(): + long_lora_offsets_list: List[int] = [] + else: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping @@ -102,10 +108,18 @@ def convert_mapping( embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx if long_lora_context: - assert long_lora_offsets is not None lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset + if current_platform.is_hpu(): + long_lora_offsets_list.append(lora_offset) + else: + assert long_lora_offsets is not None + long_lora_offsets[i] = lora_offset + + if long_lora_context and current_platform.is_hpu(): + long_lora_offsets = torch.tensor(long_lora_offsets_list, + device=device, + dtype=torch.long) indices_list: List[Union[List[int], torch.Tensor]] = [ index_mapping_indices, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 6344c3d39eb7e..86ede94abc5e9 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -232,7 +232,7 @@ def forward_hpu( ) -> Tuple[torch.Tensor, torch.Tensor]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) - if self.sin is None: + if not hasattr(self, "sin") or self.sin is None or offsets is not None: self.prepare_cos_sin(positions, offsets) num_tokens = positions.shape[0] * positions.shape[1] # HPU RoPE kernel requires hidden dimension for cos and sin to be equal From c52c37efe6017b079f7d4c2eca5e6c2697067d8f Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Thu, 2 Jan 2025 08:10:38 +0200 Subject: [PATCH 2/2] Handle cos-sin cache in every forward in case of long-context + LoRA --- vllm/model_executor/layers/rotary_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 86ede94abc5e9..a601189788441 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -232,7 +232,10 @@ def forward_hpu( ) -> Tuple[torch.Tensor, torch.Tensor]: from habana_frameworks.torch.hpex.kernels import ( RotaryPosEmbeddingMode, apply_rotary_pos_emb) - if not hasattr(self, "sin") or self.sin is None or offsets is not None: + + # Prepare cos-sin caches for long-context + LoRA with offsets for every + # forward, since the offset information wasn't available previously + if hasattr(self, "scaling_factors") or self.sin is None: self.prepare_cos_sin(positions, offsets) num_tokens = positions.shape[0] * positions.shape[1] # HPU RoPE kernel requires hidden dimension for cos and sin to be equal