From 90de09d28a68631f294fd5fe816eb31759170c06 Mon Sep 17 00:00:00 2001
From: Christopher Childs
Date: Mon, 11 Dec 2023 20:52:12 -0800
Subject: [PATCH] Cut down on duplicate model unload requests

---
 .../process_management/process_manager.py    | 72 ++++++++++++-------
 1 file changed, 45 insertions(+), 27 deletions(-)

diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index a1aaa60f..3036e705 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -90,6 +90,8 @@ class HordeProcessInfo:
     """Last time we updated the process info. If we're regularly working, then this value should change
     frequently."""
     loaded_horde_model_name: str | None = None
     """The name of the horde model that is (supposedly) currently loaded in this process."""
+    last_control_flag: HordeControlFlag | None = None
+    """The last control flag sent, to avoid duplication."""
     ram_usage_bytes: int = 0
     """The amount of RAM used by this process."""
@@ -1000,13 +1002,22 @@ def receive_and_handle_process_messages(self) -> None:
                 message.process_id in self._process_map
                 and message.horde_model_state != self._process_map[message.process_id].loaded_horde_model_name
             ):
-                loaded_message = f"Process {message.process_id} has model {message.horde_model_name} loaded. "
+                if message.horde_model_state == ModelLoadState.LOADED_IN_VRAM:
+                    loaded_message = (
+                        f"Process {message.process_id} just finished inference, and has "
+                        f"{message.horde_model_name} in VRAM."
+                    )
+                    logger.debug(loaded_message)
+                elif message.horde_model_state == ModelLoadState.LOADED_IN_RAM:
+                    loaded_message = (
+                        f"Process {message.process_id} moved model {message.horde_model_name} to system RAM. "
+                    )
 
-                if message.time_elapsed is not None:
-                    # round to 2 decimal places
-                    loaded_message += f"Loading took {message.time_elapsed:.2f} seconds"
+                    if message.time_elapsed is not None:
+                        # round to 2 decimal places
+                        loaded_message += f"Loading took {message.time_elapsed:.2f} seconds."
 
-                logger.info(loaded_message)
+                    logger.info(loaded_message)
 
             self._process_map.update_entry(
                 process_id=message.process_id,
@@ -1222,6 +1233,7 @@ def preload_models(self) -> bool:
                     seamless_tiling_enabled=seamless_tiling_enabled,
                 ),
             )
+            available_process.last_control_flag = HordeControlFlag.PRELOAD_MODEL
 
             self._horde_model_map.update_entry(
                 horde_model_name=job.model,
@@ -1296,12 +1308,14 @@ def start_inference(self) -> None:
             if process_info.loaded_horde_model_name in next_n_models:
                 continue
 
-            process_info.pipe_connection.send(
-                HordeControlModelMessage(
-                    control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
-                    horde_model_name=process_info.loaded_horde_model_name,
-                ),
-            )
+            if process_info.last_control_flag != HordeControlFlag.UNLOAD_MODELS_FROM_VRAM:
+                process_info.pipe_connection.send(
+                    HordeControlModelMessage(
+                        control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
+                        horde_model_name=process_info.loaded_horde_model_name,
+                    ),
+                )
+                process_info.last_control_flag = HordeControlFlag.UNLOAD_MODELS_FROM_VRAM
 
         logger.info(f"Starting inference for job {next_job.id_} on process {process_with_model.process_id}")
         # region Log job info
@@ -1346,6 +1360,7 @@ def start_inference(self) -> None:
                 sdk_api_job_info=next_job,
             ),
         )
+        process_with_model.last_control_flag = HordeControlFlag.START_INFERENCE
 
     def unload_from_ram(self, process_id: int) -> None:
         """Unload models from a process, either from VRAM or both VRAM and system RAM.
@@ -1364,23 +1379,26 @@ def unload_from_ram(self, process_id: int) -> None:
         if not self._horde_model_map.is_model_loaded(process_info.loaded_horde_model_name):
             raise ValueError(f"process_id {process_id} is loaded with a model that is not loaded")
 
-        process_info.pipe_connection.send(
-            HordeControlModelMessage(
-                control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_RAM,
-                horde_model_name=process_info.loaded_horde_model_name,
-            ),
-        )
+        if process_info.last_control_flag != HordeControlFlag.UNLOAD_MODELS_FROM_RAM:
+            process_info.pipe_connection.send(
+                HordeControlModelMessage(
+                    control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_RAM,
+                    horde_model_name=process_info.loaded_horde_model_name,
+                ),
+            )
 
-        self._horde_model_map.update_entry(
-            horde_model_name=process_info.loaded_horde_model_name,
-            load_state=ModelLoadState.ON_DISK,
-            process_id=process_id,
-        )
+            process_info.last_control_flag = HordeControlFlag.UNLOAD_MODELS_FROM_RAM
 
-        self._process_map.update_entry(
-            process_id=process_id,
-            loaded_horde_model_name=None,
-        )
+            self._horde_model_map.update_entry(
+                horde_model_name=process_info.loaded_horde_model_name,
+                load_state=ModelLoadState.ON_DISK,
+                process_id=process_id,
+            )
+
+            self._process_map.update_entry(
+                process_id=process_id,
+                loaded_horde_model_name=None,
+            )
 
     def get_next_n_models(self, n: int) -> set[str]:
         """Get the next n models that will be used in the job deque.
@@ -2073,7 +2091,7 @@ async def _process_control_loop(self) -> None:
 
             self.replace_hung_processes()
 
-            self.unload_models()
+            # self.unload_models()
 
             if self._shutting_down:
                 self.end_inference_processes()
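
Note (illustration only, not part of the patch): the change works by remembering the last HordeControlFlag sent to each child process and skipping the send when the same flag would be repeated. The minimal Python sketch below shows that pattern in isolation; ControlFlag, ProcessHandle, and send_once are hypothetical names for this example, not identifiers from horde_worker_regen.

# Minimal sketch of the dedup pattern, using only the standard library (hypothetical names).
from enum import Enum, auto
from typing import Callable


class ControlFlag(Enum):
    """Stand-in for the worker's control flags."""

    UNLOAD_MODELS_FROM_VRAM = auto()
    UNLOAD_MODELS_FROM_RAM = auto()


class ProcessHandle:
    """Remembers the last control flag sent so duplicate requests are skipped."""

    def __init__(self, send: Callable[[ControlFlag], None]) -> None:
        self._send = send  # e.g. a multiprocessing pipe's send()
        self.last_control_flag: ControlFlag | None = None

    def send_once(self, flag: ControlFlag) -> bool:
        """Send the flag only if it differs from the last one sent; return True if sent."""
        if self.last_control_flag == flag:
            return False  # duplicate request, drop it
        self._send(flag)
        self.last_control_flag = flag
        return True


if __name__ == "__main__":
    handle = ProcessHandle(send=print)
    handle.send_once(ControlFlag.UNLOAD_MODELS_FROM_VRAM)  # sent
    handle.send_once(ControlFlag.UNLOAD_MODELS_FROM_VRAM)  # skipped as a duplicate

In the patch itself, the same check is applied inline on HordeProcessInfo.last_control_flag before each pipe_connection.send().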