diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index 0446d43f..73db6db1 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -402,7 +402,7 @@ def __init__(
         *,
         ctx: BaseContext,
         bridge_data: reGenBridgeData,
-        target_ram_overhead_bytes: int = 8 * 1024 * 1024 * 1024,
+        target_ram_overhead_bytes: int = 10 * 1024 * 1024 * 1024,
         target_vram_overhead_bytes_map: Mapping[int, int] | None = None,  # FIXME
         max_inference_processes: int = 4,
         max_safety_processes: int = 1,
@@ -935,15 +935,22 @@ def start_inference(self) -> None:
             if process_info.loaded_horde_model_name is None:
                 continue

-            next_n_models = self.get_next_n_models(self.max_inference_processes)
+            next_n_models = list(self.get_next_n_models(self.max_inference_processes))

-            if process_info.loaded_horde_model_name not in next_n_models:
-                process_info.pipe_connection.send(
-                    HordeControlModelMessage(
-                        control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
-                        horde_model_name=process_info.loaded_horde_model_name,
-                    ),
-                )
+            # If the model would be used by another process soon, don't unload it
+            if (
+                self.max_concurrent_inference_processes > 1
+                and process_info.loaded_horde_model_name
+                in next_n_models[: self.max_concurrent_inference_processes - 1]
+            ):
+                continue
+
+            process_info.pipe_connection.send(
+                HordeControlModelMessage(
+                    control_flag=HordeControlFlag.UNLOAD_MODELS_FROM_VRAM,
+                    horde_model_name=process_info.loaded_horde_model_name,
+                ),
+            )

         time.sleep(0.1)
         logger.info(f"Starting inference for job {next_job.id_} on process {process_with_model.process_id}")
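
Note: a minimal standalone sketch of the keep-window rule introduced in the second hunk, for reviewers. The helper name should_unload_from_vram and its plain-list parameters are hypothetical, invented only for this illustration; in the actual manager the decision is made inline against self.max_concurrent_inference_processes and the result of get_next_n_models. Observe that the rewritten logic differs from the old "not in next_n_models" check: with max_concurrent_inference_processes == 1 a model is now always unloaded, even if it appears among the upcoming models.

# Hypothetical standalone illustration of the unload decision added in this
# diff; the helper name and its parameters are invented for this sketch.


def should_unload_from_vram(
    loaded_model: str,
    next_n_models: list[str],
    max_concurrent_inference_processes: int,
) -> bool:
    """Return True if `loaded_model` should be unloaded from VRAM.

    Mirrors the new logic: with more than one concurrent inference
    process, a model appearing in the first
    `max_concurrent_inference_processes - 1` upcoming models is kept
    loaded, since another process is likely to use it soon.
    """
    if (
        max_concurrent_inference_processes > 1
        and loaded_model in next_n_models[: max_concurrent_inference_processes - 1]
    ):
        return False
    return True


# With two concurrent processes, only the single next model is protected:
assert should_unload_from_vram("a", ["a", "b", "c"], 2) is False
assert should_unload_from_vram("b", ["a", "b", "c"], 2) is True
# With one concurrent process, the model is always unloaded:
assert should_unload_from_vram("a", ["a", "b", "c"], 1) is True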