Fix LoRA tests (#696)
This PR updates `tests/lora/utils.py` based on the latest rebase.
SanjuCSudhakaran authored Jan 20, 2025
1 parent 018ce62 commit b10992b
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions tests/lora/test_lora_hpu.py

@@ -41,7 +41,7 @@ def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
 @pytest.mark.parametrize("rank", RANKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 def test_apply_lora(m, n, k, rank, dtype) -> None:
-    manager = DummyLoRAManager()
+    manager = DummyLoRAManager(device="hpu")
 
     module_name = "module"
     weight = torch.rand([m, n], device="hpu", dtype=dtype)
@@ -111,7 +111,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
     if m // 2 not in TENSOR_SIZES:
         pytest.skip("m//2 must be in TENSOR_SIZES")
 
-    manager = DummyLoRAManager()
+    manager = DummyLoRAManager(device="hpu")
 
     module_name = "module"
     weight = torch.rand([m // 2, n], device="hpu", dtype=dtype)
@@ -183,7 +183,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
 @pytest.mark.parametrize("rank", RANKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
-    manager = DummyLoRAManager()
+    manager = DummyLoRAManager(device="hpu")
 
     module_name = "module"
     weight_q = torch.empty(qkv[0], n, device="hpu", dtype=dtype)
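For context, the pattern the updated tests follow is to pass the target device to the dummy manager explicitly instead of relying on a global helper. Below is a minimal usage sketch, assuming an HPU-enabled environment where `tests/lora/utils.py` is importable as `tests.lora.utils`; the shapes, dtype, module name, and the `init_random_lora` call signature are illustrative assumptions, not taken verbatim from the test file.

```python
import torch

from tests.lora.utils import DummyLoRAManager

# The manager is told which device to put LoRA weights on (here Gaudi/HPU).
manager = DummyLoRAManager(device="hpu")

# Illustrative layer weight; the real tests parametrize m, n, k, rank, dtype.
weight = torch.rand([256, 128], device="hpu", dtype=torch.bfloat16)
lora = manager.init_random_lora("module", weight, rank=8)

# lora_a / lora_b should land on the same device type as the layer weight.
assert lora.lora_a.device.type == weight.device.type
assert lora.lora_b.device.type == weight.device.type
```

Passing the device through the constructor keeps the fabricated LoRA weights colocated with the HPU layer weights the tests build, which is what the `device="hpu"` additions above accomplish.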
9 changes: 4 additions & 5 deletions tests/lora/utils.py

@@ -3,7 +3,6 @@
 import torch
 
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.utils import get_device
 
 
 class DummyLoRAManager:
@@ -32,10 +31,10 @@ def init_random_lora(
             lora_alpha=1,
             lora_a=torch.rand([weight.shape[1], rank],
                               dtype=weight.dtype,
-                              device=get_device()),
+                              device=self._device),
             lora_b=torch.rand([rank, weight.shape[0]],
                               dtype=weight.dtype,
-                              device=get_device()),
+                              device=self._device),
         )
         if generate_embeddings_tensor:
             lora.embeddings_tensor = torch.rand(
@@ -61,8 +60,8 @@ def init_lora(
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([input_dim, rank], device=get_device()),
-            lora_b=torch.rand([rank, output_dim], device=get_device()),
+            lora_a=torch.rand([input_dim, rank], device=self._device),
+            lora_b=torch.rand([rank, output_dim], device=self._device),
             embeddings_tensor=embeddings_tensor,
         )
         self.set_module_lora(module_name, lora)
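The hunks above reference `self._device`, but the constructor that sets it sits outside the diff context. The sketch below is a hypothetical reconstruction of what the rebased `DummyLoRAManager.__init__` presumably looks like; the `_loras` attribute, docstring, and default device are assumptions, not part of this commit.

```python
from vllm.lora.lora import LoRALayerWeights


class DummyLoRAManager:
    """Test-only helper that fabricates LoRA weights for named modules."""

    def __init__(self, device="cuda:0"):
        super().__init__()
        # Assumed registry for LoRAs added via set_module_lora() below.
        self._loras: dict[str, LoRALayerWeights] = {}
        # Device referenced as self._device by init_random_lora()/init_lora()
        # in the diff above, replacing the removed get_device() helper.
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights) -> None:
        self._loras[module_name] = lora
```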
