Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix long contexts in LoRA #624

Open
wants to merge 4 commits into
base: habana_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):

@pytest.fixture
def dist_init():
import habana_frameworks.torch.hpu # noqa: F401
temp_file = tempfile.mkstemp()[1]
backend_type = "hccl" if current_platform.is_hpu() else "nccl"
init_distributed_environment(
Expand Down
24 changes: 19 additions & 5 deletions vllm/lora/punica_wrapper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from vllm.platforms import current_platform

if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
Expand Down Expand Up @@ -86,10 +88,14 @@ def convert_mapping(
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None

if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
if current_platform.is_hpu():
long_lora_offsets_list: List[int] = []
else:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=device,
dtype=torch.long)
prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
Expand All @@ -102,10 +108,18 @@ def convert_mapping(
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
if current_platform.is_hpu():
long_lora_offsets_list.append(lora_offset)
else:
assert long_lora_offsets is not None
long_lora_offsets[i] = lora_offset

if long_lora_context and current_platform.is_hpu():
long_lora_offsets = torch.tensor(long_lora_offsets_list,
device=device,
dtype=torch.long)

indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def forward_hpu(
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if self.sin is None:
if not hasattr(self, "sin") or self.sin is None or offsets is not None:
self.prepare_cos_sin(positions, offsets)
num_tokens = positions.shape[0] * positions.shape[1]
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
Expand Down
15 changes: 10 additions & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from vllm.distributed.parallel_state import get_world_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora, LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
Expand Down Expand Up @@ -169,7 +169,7 @@ def forward_hook(module, args, output):
modify_decoder_layer(child_module, suffix, n, counter)


def get_names_for_rope(model: torch.nn.Module):
def get_names_for_rope(model: torch.nn.Module, lora_config):
"""Dynamically get layer names needed for cos and sin preparation for rope.

Every model can have a different naming convention for it's layers.
Expand All @@ -194,7 +194,11 @@ def get_child(parent, suffix, is_list=False):
attn_name, attn_module = get_child(layers_module,
"Attention",
is_list=True)
rope_name, _ = get_child(attn_module, "RotaryEmbedding")
if lora_config and lora_config.long_lora_scaling_factors:
rope_class_name = "LinearScalingRotaryEmbeddingWithLora"
else:
rope_class_name = "RotaryEmbedding"
rope_name, _ = get_child(attn_module, rope_class_name)

if rope_name is not None:
return {
Expand Down Expand Up @@ -363,7 +367,8 @@ def _prepare_cos_sin(self, positions):
attention_layer = getattr(first_model_layer, attn_name)
rope = getattr(attention_layer, rope_name)

rope.prepare_cos_sin(positions)
if not isinstance(rope, LinearScalingRotaryEmbeddingWithLora):
rope.prepare_cos_sin(positions)

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
Expand Down Expand Up @@ -744,7 +749,7 @@ def load_model(self) -> None:
get_decoder_layer_suffix(model_config.model_type if
model_config is not None else None),
hidden_layer_markstep_interval)
names_for_rope = get_names_for_rope(self.model)
names_for_rope = get_names_for_rope(self.model, self.lora_config)
torch.hpu.synchronize()

with HabanaMemoryProfiler() as m_wrap:
Expand Down
Loading