diff --git a/rl/llm/engines.py b/rl/llm/engines.py index b647f17..f8c47a6 100644 --- a/rl/llm/engines.py +++ b/rl/llm/engines.py @@ -530,6 +530,16 @@ def _get_vllm_engine( ) engine_args_kwargs = _get_vllm_kwargs(llm_config) + if ( + engine_args_kwargs["tensor_parallel_size"] > 1 + and "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ + ): + LOGGER.warning( + "Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn' to avoid issues with " + "CUDA re-initialization." + ) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + engine_cls = AsyncLLMEngine if use_async else LLMEngine engine_args_cls = AsyncEngineArgs if use_async else EngineArgs engine_args = engine_args_cls(**engine_args_kwargs) # type: ignore