Cut down on duplicate model unload requests
zten authored and tazlin committed Dec 14, 2023
1 parent 7a1b1d3 commit 90de09d
Showing 1 changed file with 45 additions and 27 deletions.
horde_worker_regen/process_management/process_manager.py (72 changes: 45 additions & 27 deletions)
@@ -90,6 +90,8 @@ class HordeProcessInfo:
     """Last time we updated the process info. If we're regularly working, then this value should change frequently."""
     loaded_horde_model_name: str | None = None
     """The name of the horde model that is (supposedly) currently loaded in this process."""
+    last_control_flag: HordeControlFlag | None = None
+    """The last control flag sent, to avoid duplication."""
 
     ram_usage_bytes: int = 0
     """The amount of RAM used by this process."""
@@ -1000,13 +1002,22 @@ def receive_and_handle_process_messages(self) -> None:
                 message.process_id in self._process_map
                 and message.horde_model_state != self._process_map[message.process_id].loaded_horde_model_name
             ):
-                loaded_message = f"Process {message.process_id} has model {message.horde_model_name} loaded. "
+                if message.horde_model_state == ModelLoadState.LOADED_IN_VRAM:
+                    loaded_message = (
+                        f"Process {message.process_id} just finished inference, and has "
+                        f"{message.horde_model_name} in VRAM."
+                    )
+                    logger.debug(loaded_message)
+                elif message.horde_model_state == ModelLoadState.LOADED_IN_RAM:
+                    loaded_message = (
+                        f"Process {message.process_id} moved model {message.horde_model_name} to system RAM. "
+                    )
 
-                if message.time_elapsed is not None:
-                    # round to 2 decimal places
-                    loaded_message += f"Loading took {message.time_elapsed:.2f} seconds"
+                    if message.time_elapsed is not None:
+                        # round to 2 decimal places
+                        loaded_message += f"Loading took {message.time_elapsed:.2f} seconds."
 
-                logger.info(loaded_message)
+                    logger.info(loaded_message)
 
                 self._process_map.update_entry(
                     process_id=message.process_id,
@@ -1222,6 +1233,7 @@ def preload_models(self) -> bool:
                     seamless_tiling_enabled=seamless_tiling_enabled,
                 ),
             )
+            available_process.last_control_flag = HordeControlFlag.PRELOAD_MODEL
 
             self._horde_model_map.update_entry(
                 horde_model_name=job.model,
@@ -1296,12 +1308,14 @@ def start_inference(self) -> None:
                 if process_info.loaded_horde_model_name in next_n_models:
                     continue
 
-                process_info.pipe_connection.send(
-                    HordeControlModelMessage(
-                        control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
-                        horde_model_name=process_info.loaded_horde_model_name,
-                    ),
-                )
+                if process_info.last_control_flag != HordeControlFlag.UNLOAD_MODELS_FROM_VRAM:
+                    process_info.pipe_connection.send(
+                        HordeControlModelMessage(
+                            control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
+                            horde_model_name=process_info.loaded_horde_model_name,
+                        ),
+                    )
+                    process_info.last_control_flag = HordeControlFlag.UNLOAD_MODELS_FROM_VRAM
 
         logger.info(f"Starting inference for job {next_job.id_} on process {process_with_model.process_id}")
         # region Log job info
Expand Down Expand Up @@ -1346,6 +1360,7 @@ def start_inference(self) -> None:
sdk_api_job_info=next_job,
),
)
process_with_model.last_control_flag = HordeControlFlag.START_INFERENCE

def unload_from_ram(self, process_id: int) -> None:
"""Unload models from a process, either from VRAM or both VRAM and system RAM.
@@ -1364,23 +1379,26 @@ def unload_from_ram(self, process_id: int) -> None:
         if not self._horde_model_map.is_model_loaded(process_info.loaded_horde_model_name):
             raise ValueError(f"process_id {process_id} is loaded with a model that is not loaded")
 
-        process_info.pipe_connection.send(
-            HordeControlModelMessage(
-                control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_RAM,
-                horde_model_name=process_info.loaded_horde_model_name,
-            ),
-        )
+        if process_info.last_control_flag != HordeControlFlag.UNLOAD_MODELS_FROM_RAM:
+            process_info.pipe_connection.send(
+                HordeControlModelMessage(
+                    control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_RAM,
+                    horde_model_name=process_info.loaded_horde_model_name,
+                ),
+            )
 
-        self._horde_model_map.update_entry(
-            horde_model_name=process_info.loaded_horde_model_name,
-            load_state=ModelLoadState.ON_DISK,
-            process_id=process_id,
-        )
+            self._horde_model_map.update_entry(
+                horde_model_name=process_info.loaded_horde_model_name,
+                load_state=ModelLoadState.ON_DISK,
+                process_id=process_id,
+            )
+            process_info.last_control_flag = HordeControlFlag.UNLOAD_MODELS_FROM_RAM
 
-        self._process_map.update_entry(
-            process_id=process_id,
-            loaded_horde_model_name=None,
-        )
+            self._process_map.update_entry(
+                process_id=process_id,
+                loaded_horde_model_name=None,
+            )
 
     def get_next_n_models(self, n: int) -> set[str]:
         """Get the next n models that will be used in the job deque.
@@ -2073,7 +2091,7 @@ async def _process_control_loop(self) -> None:
 
             self.replace_hung_processes()
 
-            self.unload_models()
+            # self.unload_models()
 
             if self._shutting_down:
                 self.end_inference_processes()
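Note that only the two unload paths gain a guard; the PRELOAD_MODEL and START_INFERENCE sends still always go out and merely record their flag, which resets the guard so that the next unload request will be sent again. That matches the commit title: it is specifically duplicate unload requests being cut down. The final hunk also comments out the control loop's periodic unload_models() call, so unloading now appears to be driven only by the explicit paths above rather than by every loop iteration.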