PP seems to be working.
jmaksymczuk committed Sep 18, 2024
1 parent 65c40eb · commit 963fefd
Showing 2 changed files with 7 additions and 7 deletions.
vllm/executor/distributed_habana_executor.py (2 changes: 0 additions, 2 deletions)
@@ -85,7 +85,6 @@ def execute_model(
             'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
         log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
                                            '0') != '0' or log_cpu_fallbacks_all
-        print("\n\n\n EXECUTING MODEL \n\n\n")
         if log_graph_compilation or log_cpu_fallbacks:
             from habana_frameworks.torch.hpu.metrics import metric_localcontext
             seq_group_metadata_list = execute_model_req.seq_group_metadata_list
@@ -115,7 +114,6 @@ def execute_model(
                 cpu_fallback_ctx as cpu_fallback_local_metric:
             # output = self.driver_worker.execute_model(execute_model_req)
             if self.parallel_worker_tasks is None:
-                print("\n\n\n CHECK CHECK CHECK \n\n\n")
                 self.parallel_worker_tasks = self._run_workers(
                     "start_worker_execution_loop",
                     async_run_tensor_parallel_workers_only=True,
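For context, both removed prints sat in a code path whose diagnostic logging the surrounding hunk gates on two environment variables. A minimal, self-contained sketch of that gating (the _flag helper is a hypothetical convenience, not part of the file; the variable names are taken from the hunk):

    import os

    def _flag(name: str) -> bool:
        # '0' (the default) disables a flag; any other value enables it.
        return os.environ.get(name, '0') != '0'

    log_cpu_fallbacks_all = _flag('VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL')
    # The per-step flag is ORed with the *_ALL flag, so setting either
    # variable to a non-zero value enables CPU-fallback logging.
    log_cpu_fallbacks = (_flag('VLLM_HPU_LOG_STEP_CPU_FALLBACKS')
                         or log_cpu_fallbacks_all)

Launching with VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL=1, for example, turns the logging on for every step without touching the per-step flag.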
vllm/executor/habana_executor.py (12 changes: 7 additions, 5 deletions)
@@ -13,12 +13,14 @@
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (HabanaMemoryProfiler, get_distributed_init_method,
                         get_ip, get_open_port, make_async)
-from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
+from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
 def create_worker(worker_module_name, worker_class_name, **kwargs):
     wrapper = WorkerWrapperBase(
+        # worker_module_name="vllm.worker.habana_worker",
+        # worker_class_name="HabanaWorker",
         worker_module_name=worker_module_name,
         worker_class_name=worker_class_name,
     )
@@ -65,12 +67,12 @@ def _get_create_worker_kwargs(
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
         if self.speculative_config is None:
-            worker_kwargs.update(worker_module_name="vllm.worker.worker",
-                                 worker_class_name="Worker")
+            worker_kwargs.update(worker_module_name="vllm.worker.habana_worker",
+                                 worker_class_name="HabanaWorker",)
         else:
             worker_kwargs.update(
-                worker_module_name="vllm.spec_decode.spec_decode_worker",
-                worker_class_name="create_spec_worker")
+                worker_module_name="vllm.worker.habana_worker",
+                worker_class_name="HabanaWorker",)
         return worker_kwargs
 
     def _create_worker(self,
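The create_worker helper shown (truncated) above follows the factory pattern vLLM's other executors used at the time: WorkerWrapperBase resolves the worker class lazily from its module and class names, so executor modules avoid importing device-specific code at load time. A hedged sketch of the presumed full helper (the init_worker call and the return are assumptions mirroring the GPU executor; the diff cuts off after the constructor):

    from vllm.worker.worker_base import WorkerWrapperBase

    def create_worker(worker_module_name, worker_class_name, **kwargs):
        # Resolving the worker by module/class name keeps HPU-specific
        # imports out of the executor's import path until a worker is built.
        wrapper = WorkerWrapperBase(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )
        # Assumed continuation: construct the wrapped worker from the
        # executor-supplied kwargs and hand it back.
        wrapper.init_worker(**kwargs)
        return wrapper.worker

With _get_create_worker_kwargs now returning vllm.worker.habana_worker / HabanaWorker on both branches, every worker built through this helper is a HabanaWorker, including on the speculative-decoding path.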
