From 51708121cdeccb33a7d393f72e1961a801c986ed Mon Sep 17 00:00:00 2001 From: Christopher Childs Date: Sat, 11 Nov 2023 19:37:38 -0800 Subject: [PATCH] Replace processes when a model is unloaded On Linux, it seems like there is a tremendous amount of memory allocated by something outside Python when you allow one worker to process many jobs on many different models. In order to limit the damage from that behavior, we'll try to replace the processes when the model is scheduled to be unloaded. I realize this probably makes it a _little_ harder to decouple a process from a model, but it's a huge stability improvement. This also switches the model management strategy a tiny bit by allocating a model to every open worker before trying to unload a model. Previously, you would have N processes = `threads` + `queue`, but jobs were very likely to be scheduled on only the `threads` number of workers. --- .../process_management/process_manager.py | 169 +++++++++++++----- 1 file changed, 129 insertions(+), 40 deletions(-) diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 5520fed5..3dd3bb32 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -1,5 +1,7 @@ import asyncio import base64 +import collections +import datetime import multiprocessing import os import random @@ -81,6 +83,8 @@ class HordeProcessInfo: """The type of this process.""" last_process_state: HordeProcessState """The last known state of this process.""" + last_timestamp: datetime.datetime + """Last time we updated the process info. If we're regularly working, then this value should change frequently.""" loaded_horde_model_name: str | None = None """The name of the horde model that is (supposedly) currently loaded in this process.""" @@ -115,6 +119,7 @@ def __init__( self.process_id = process_id self.process_type = process_type self.last_process_state = last_process_state + self.last_timestamp = datetime.datetime.now() def is_process_busy(self) -> bool: """Return true if the process is actively engaged in a task. @@ -186,6 +191,9 @@ def update_entry( if process_id is not None: self.root[horde_model_name].process_id = process_id + def expire_entry(self, horde_model_name: str): + self.root.pop(horde_model_name, 'None') + def is_model_loaded(self, horde_model_name: str) -> bool: """Return true if the given model is loaded in any process.""" if horde_model_name not in self.root: @@ -242,6 +250,8 @@ def update_entry( if total_vram_bytes is not None: self[process_id].total_vram_bytes = total_vram_bytes + self[process_id].last_timestamp = datetime.datetime.now() + def num_inference_processes(self) -> int: """Return the number of inference processes.""" count = 0 @@ -260,6 +270,12 @@ def num_available_inference_processes(self) -> int: def get_first_available_inference_process(self) -> HordeProcessInfo | None: """Return the first available inference process, or None if there are none available.""" + for p in self.values(): + if p.process_type == HordeProcessType.INFERENCE \ + and p.last_process_state == HordeProcessState.WAITING_FOR_JOB \ + and p.loaded_horde_model_name is None: + return p + for p in self.values(): if p.process_type == HordeProcessType.INFERENCE and p.can_accept_job(): if p.last_process_state == HordeProcessState.PRELOADED_MODEL: @@ -336,7 +352,8 @@ def __repr__(self) -> str: base_string = "Processes: " for process_id, process_info in self.items(): if process_info.process_type == HordeProcessType.INFERENCE: - base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) " + base_string += (f"{process_id}: ({process_info.loaded_horde_model_name} " + f"[last event: {process_info.last_timestamp}]) ") else: base_string += f"{process_id}: ({process_info.process_type.name}) " base_string += f"{process_info.last_process_state.name}; " @@ -380,6 +397,21 @@ def is_job_checked_for_safety(self) -> bool: return self.censored is not None +class LRUCache: + def __init__(self, capacity): + self.capacity = capacity + self.cache = collections.OrderedDict() + + def append(self, key): + bumped = None + if key in self.cache: + self.cache.move_to_end(key) + elif len(self.cache) >= self.capacity: + bumped, _ = self.cache.popitem(last=False) + self.cache[key] = None + return bumped + + class HordeWorkerProcessManager: """Manages and controls processes to act as a horde worker.""" @@ -412,6 +444,9 @@ def max_concurrent_inference_processes(self) -> int: target_vram_overhead_bytes_map: Mapping[int, int] | None = None + process_timeout: datetime.timedelta + """Max amount of time a job can go without checking in with the main process manager""" + @property def max_queue_size(self) -> int: """The maximum number of jobs that can be queued.""" @@ -501,6 +536,8 @@ def num_total_processes(self) -> int: _shutting_down = False + _lru: LRUCache + def __init__( self, *, @@ -542,6 +579,8 @@ def __init__( self._inference_semaphore = Semaphore(self._max_concurrent_inference_processes, ctx=ctx) self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads + self._lru = LRUCache(self.max_inference_processes) + self.process_timeout = datetime.timedelta(minutes=5) # If there is only one model to load and only one inference process, then we can only run one job at a time # and there is no point in having more than one inference process @@ -743,33 +782,34 @@ def start_inference_processes(self) -> None: for _ in range(num_processes_to_start): # Create a two-way communication pipe for the parent and child processes pid = len(self._process_map) - pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True) - - # Create a new process that will run the start_inference_process function - process = multiprocessing.Process( - target=start_inference_process, - args=( - pid, - self._process_message_queue, - child_pipe_connection, - self._inference_semaphore, - self._disk_lock, - ), - ) - - process.start() - - # Add the process to the process map - self._process_map[pid] = HordeProcessInfo( - mp_process=process, - pipe_connection=pipe_connection, - process_id=pid, - process_type=HordeProcessType.INFERENCE, - last_process_state=HordeProcessState.PROCESS_STARTING, - ) + self._start_inference_process(pid) logger.info(f"Started inference process (id: {pid})") + def _start_inference_process(self, pid): + logger.info(f"Starting inference process on PID {pid}") + pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True) + # Create a new process that will run the start_inference_process function + process = multiprocessing.Process( + target=start_inference_process, + args=( + pid, + self._process_message_queue, + child_pipe_connection, + self._inference_semaphore, + self._disk_lock, + ), + ) + process.start() + # Add the process to the process map + self._process_map[pid] = HordeProcessInfo( + mp_process=process, + pipe_connection=pipe_connection, + process_id=pid, + process_type=HordeProcessType.INFERENCE, + last_process_state=HordeProcessState.PROCESS_STARTING, + ) + def end_inference_processes(self) -> None: """End any inference processes above the configured limit, or all of them if shutting down.""" if len(self.job_deque) > 0 and len(self.job_deque) != len(self.jobs_in_progress): @@ -778,19 +818,23 @@ def end_inference_processes(self) -> None: # Get the process to end process_info = self._process_map._get_first_inference_process_to_kill() - if process_info is None: - return + if process_info is not None: + self._end_inference_process(process_info) + def _end_inference_process(self, process_info): # Send the process a message to end process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS)) - # Update the process map self._process_map.update_entry(process_id=process_info.process_id) - logger.info(f"Ended inference process {process_info.process_id}") - # Join the process with a timeout of 1 second process_info.mp_process.join(timeout=1) + process_info.mp_process.kill() + + def _replace_inference_process(self, process_info): + logger.debug(f"Replacing {process_info}") + self._end_inference_process(process_info) + self._start_inference_process(process_info.process_id) total_num_completed_jobs: int = 0 @@ -1058,11 +1102,31 @@ def preload_models(self) -> bool: if model_is_loaded: continue - available_process = self._process_map.get_first_available_inference_process() + available_process = None + model_to_unload = self._lru.append(job.model) + if model_to_unload is not None: + for p in self._process_map.values(): + if p.loaded_horde_model_name == model_to_unload and \ + (p.last_process_state == HordeProcessState.INFERENCE_COMPLETE or \ + p.last_process_state == HordeProcessState.WAITING_FOR_JOB): + available_process = p + if available_process is None: + available_process = self._process_map.get_first_available_inference_process() if available_process is None: return False + if available_process.last_process_state != HordeProcessState.WAITING_FOR_JOB \ + and available_process.loaded_horde_model_name is not None: + # We're going to restart the process and then exit the loop, because + # available_process is very quickly _not_ going to be available. + # We also don't want to block waiting for the newly forked job to become + # available, so we'll wait for it to become ready before scheduling a model + # to be loaded on it. + self._replace_inference_process(available_process) + self._horde_model_map.expire_entry(available_process.loaded_horde_model_name) + return False + logger.debug(f"Preloading model {job.model} on process {available_process.process_id}") logger.debug(f"Available inference processes: {self._process_map}") logger.debug(f"Horde model map: {self._horde_model_map}") @@ -1145,11 +1209,7 @@ def start_inference(self) -> None: next_n_models = list(self.get_next_n_models(self.max_inference_processes)) # If the model would be used by another process soon, don't unload it - if ( - self.max_concurrent_inference_processes > 1 - and process_info.loaded_horde_model_name - in next_n_models[: self.max_concurrent_inference_processes - 1] - ): + if process_info.loaded_horde_model_name in next_n_models: continue process_info.pipe_connection.send( @@ -1770,6 +1830,8 @@ async def api_job_pop(self) -> None: async with self._job_deque_lock, self._job_pop_timestamps_lock: self.job_deque.append(job_pop_response) + jobs = list(map(lambda x: f'<{x.id_}: {x.model}>', self.job_deque)) + logger.info(f'Job queue: {", ".join(jobs)}') # self._testing_jobs_added += 1 self.job_pop_timestamps[str(job_pop_response.id_)] = time.time() @@ -1885,6 +1947,7 @@ async def _process_control_loop(self) -> None: self.start_inference_processes() while True: + logger.debug("_process_control_loop looped") try: if self.stable_diffusion_reference is None: return @@ -1908,6 +1971,8 @@ async def _process_control_loop(self) -> None: async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock: self.receive_and_handle_process_messages() + self.replace_hung_processes() + self.unload_models() if self._shutting_down: @@ -1921,6 +1986,8 @@ async def _process_control_loop(self) -> None: logger.info(f"{self._process_map}") logger.info(f"Threads being used: {self._max_concurrent_inference_processes}") logger.info(f"Number of jobs popped: {len(self.job_deque)}") + jobs = list(map(lambda x: f'<{x.id_}: {x.model}>', self.job_deque)) + logger.info(f'Job queue: {", ".join(jobs)}') logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}") logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}") logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}") @@ -2002,13 +2069,28 @@ async def _bridge_data_loop(self) -> None: except CancelledError: self._shutting_down = True + @staticmethod + async def _handle_exception(task): + try: + await task + except Exception as e: + logger.error(f"Caught exception in task {task}: {e}") + async def _main_loop(self) -> None: # Run both loops concurrently + process_control_loop = asyncio.create_task(self._process_control_loop(), name="process_control_loop") + api_call_loop = asyncio.create_task(self._api_call_loop(), name="api_call_loop") + job_submit_loop = asyncio.create_task(self._job_submit_loop(), name="job_submit_loop") + bridge_data_loop = asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop") + process_control_loop.add_done_callback(self._handle_exception) + api_call_loop.add_done_callback(self._handle_exception) + job_submit_loop.add_done_callback(self._handle_exception) + bridge_data_loop.add_done_callback(self._handle_exception) await asyncio.gather( - asyncio.create_task(self._process_control_loop(), name="process_control_loop"), - asyncio.create_task(self._api_call_loop(), name="api_call_loop"), - asyncio.create_task(self._job_submit_loop(), name="job_submit_loop"), - asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop"), + process_control_loop, + api_call_loop, + job_submit_loop, + bridge_data_loop, ) _caught_sigints = 0 @@ -2046,3 +2128,10 @@ def shutdown() -> None: sys.exit(0) threading.Thread(target=shutdown).start() + + def replace_hung_processes(self): + now = datetime.datetime.now() + for pid, process_info in self._process_map.items(): + if (now - process_info.last_timestamp) > self.process_timeout: + logger.error(f"{process_info} has exceeded its timeout and will be replaced") + self._replace_inference_process(process_info)