From 13ee02cb79b3959934e4e3dee5cb978c69630c03 Mon Sep 17 00:00:00 2001 From: walmartbaggggggg <148656924+walmartbaggggggg@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:56:49 -0400 Subject: [PATCH] Update process_manager.py Cache removal, free up ram. --- .../process_management/process_manager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 88f944f7..507e5780 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -13,7 +13,6 @@ from multiprocessing.context import BaseContext from multiprocessing.synchronize import Lock as Lock_MultiProcessing from multiprocessing.synchronize import Semaphore - import aiohttp import PIL import PIL.Image @@ -60,8 +59,8 @@ ModelInfo, ModelLoadState, ) -from horde_worker_regen.process_management.worker_entry_points import start_inference_process, start_safety_process +from horde_worker_regen.process_management.worker_entry_points import start_inference_process, start_safety_process try: from multiprocessing.connection import PipeConnection as Connection # type: ignore except Exception: @@ -1206,6 +1205,7 @@ def unload_from_ram(self, process_id: int) -> None: ) def get_next_n_models(self, n: int) -> set[str]: + """Get the next n models that will be used in the job deque. Args: @@ -1477,13 +1477,16 @@ async def api_submit_job(self) -> None: f"kudos. Job popped {time_taken} seconds ago and took {completed_job_info.time_to_generate:.2f} " f"to generate. ({kudos_per_second:.2f} kudos/second. 0.4 or greater is ideal)", ) + torch.cuda.empty_cache() + logger.info("Cache removal success") # If the job was faulted, log an error else: logger.error( f"{job_info.id_} faulted, not submitting for kudos. Job popped {time_taken} seconds ago and took " f"{completed_job_info.time_to_generate:.2f} to generate.", ) - + torch.cuda.empty_cache() + logger.info("Cache removal success") # If the job took a long time to generate, log a warning (unless speed warnings are suppressed) if not self.bridge_data.suppress_speed_warnings: if job_submit_response.reward > 0 and (job_submit_response.reward / time_taken) < 0.1: @@ -1708,7 +1711,6 @@ async def api_job_pop(self) -> None: return logger.info(f"Popped job {job_pop_response.id_} (model: {job_pop_response.model})") - # region TODO: move to horde_sdk if job_pop_response.payload.seed is None: # TODO # FIXME logger.warning(f"Job {job_pop_response.id_} has no seed!")