Add multiprocessing HPU executor (HabanaAI#559)
kzawora-intel authored Dec 6, 2024
1 parent 6a4f673 commit e0e47ed
Showing 4 changed files with 60 additions and 8 deletions.
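
For context, a minimal usage sketch of what this commit enables: selecting the multiprocessing backend on a multi-HPU host instead of Ray. This assumes an HPU-enabled vLLM build exposing the standard distributed_executor_backend engine argument; the model name is illustrative.

from vllm import LLM, SamplingParams

# Illustrative only: any Gaudi-compatible model works here.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,             # shard across 2 local HPUs
    distributed_executor_backend="mp",  # the backend this commit adds for HPU
)
outputs = llm.generate(["Hello, Gaudi!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)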
7 changes: 0 additions & 7 deletions vllm/config.py
@@ -989,13 +989,6 @@ def __init__(
                 raise ValueError(
                     "TPU backend only supports Ray for distributed inference.")
 
-        if current_platform.is_hpu() and self.world_size > 1:
-            if self.distributed_executor_backend is None:
-                self.distributed_executor_backend = "ray"
-            if self.distributed_executor_backend != "ray":
-                raise ValueError(
-                    "HPU backend only supports Ray for distributed inference.")
-
         if self.distributed_executor_backend is None and self.world_size > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.
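
With the HPU-specific Ray requirement deleted, HPU falls through to the generic default selection just below this hunk. A simplified, hedged paraphrase of that behavior (the real logic in vllm/config.py also accounts for Ray placement groups):

def pick_default_backend(world_size: int, local_device_count: int) -> str:
    """Prefer local multiprocessing when a single node has enough devices."""
    if world_size <= local_device_count:
        return "mp"   # workers run as local processes
    return "ray"      # multi-node setups still need Ray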
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -630,6 +630,10 @@ def _get_executor_cls(
             from vllm.executor.cpu_executor import CPUExecutorAsync
             executor_class = CPUExecutorAsync
         elif engine_config.device_config.device_type == "hpu":
+            if distributed_executor_backend == "mp":
+                from vllm.executor.multiproc_hpu_executor import (
+                    MultiprocessingHPUExecutorAsync)
+                executor_class = MultiprocessingHPUExecutorAsync
             if distributed_executor_backend == "ray":
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
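
The async path mirrors the sync one. A hedged sketch of driving it, assuming vLLM's public AsyncEngineArgs/AsyncLLMEngine API (model name illustrative):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
    tensor_parallel_size=2,
    distributed_executor_backend="mp",
)
# On HPU, _get_executor_cls above now resolves MultiprocessingHPUExecutorAsync.
engine = AsyncLLMEngine.from_engine_args(engine_args)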
6 changes: 5 additions & 1 deletion vllm/engine/llm_engine.py
@@ -529,7 +529,11 @@ def _get_executor_cls(cls,
             from vllm.executor.cpu_executor import CPUExecutor
             executor_class = CPUExecutor
         elif engine_config.device_config.device_type == "hpu":
-            if distributed_executor_backend == "ray":
+            if distributed_executor_backend == "mp":
+                from vllm.executor.multiproc_hpu_executor import (
+                    MultiprocessingHPUExecutor)
+                executor_class = MultiprocessingHPUExecutor
+            elif distributed_executor_backend == "ray":
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_hpu_executor import RayHPUExecutor
                 executor_class = RayHPUExecutor
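
For readers skimming the dispatch, a self-contained sketch of the pattern this hunk extends; purely illustrative, since the real method imports executor classes lazily and covers other device types too:

from typing import Optional

def pick_hpu_executor(backend: Optional[str]) -> str:
    # Mirrors the hunk above: "mp" first, then "ray", else single-process.
    if backend == "mp":
        return "MultiprocessingHPUExecutor"  # added by this commit
    if backend == "ray":
        return "RayHPUExecutor"
    return "HPUExecutor"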
51 changes: 51 additions & 0 deletions vllm/executor/multiproc_hpu_executor.py
@@ -0,0 +1,51 @@
+from typing import Callable, Optional, Tuple, Type
+
+import habana_frameworks.torch  # noqa: F401
+import torch
+
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
+from vllm.logger import init_logger
+from vllm.utils import make_async
+from vllm.worker.worker_base import WorkerBase
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingHPUExecutor(MultiprocessingGPUExecutor):
+    """Python multiprocessing-based multi-HPU executor."""
+
+    def _get_worker_module_and_class(
+            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
+        worker_class_fn = None
+        if self.scheduler_config.is_multi_step:
+            module_name = "vllm.worker.multi_step_hpu_worker"
+            class_name = "MultiStepHPUWorker"
+        elif self.speculative_config is not None:
+            module_name = "vllm.spec_decode.spec_decode_worker"
+            class_name = "create_spec_worker"
+        else:
+            module_name = "vllm.worker.hpu_worker"
+            class_name = "HPUWorker"
+        return (module_name, class_name, worker_class_fn)
+
+    def _check_executor_parameters(self):
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        hpu_device_count = torch.hpu.device_count()
+        assert tensor_parallel_size <= hpu_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to at most the local hpu count ({hpu_device_count})")
+
+        assert world_size <= hpu_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"does not exceed the local hpu count ({hpu_device_count})")
+
+
+class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor,
+                                      MultiprocessingGPUExecutorAsync):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
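
A hedged sketch of how the (module, class) pair returned by _get_worker_module_and_class is typically resolved; the actual wiring lives in the MultiprocessingGPUExecutor base class, so this is illustrative only:

import importlib

def resolve_worker(module_name: str, class_name: str):
    """Import a worker class lazily from its dotted module path."""
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Default HPU path from the executor above:
# worker_cls = resolve_worker("vllm.worker.hpu_worker", "HPUWorker")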
