From 3fceb7504bf464231b0001c9c00af88b90425548 Mon Sep 17 00:00:00 2001 From: tazlin Date: Mon, 25 Mar 2024 07:37:59 -0400 Subject: [PATCH 01/14] fix: don't send heartbeat until the first step --- horde_worker_regen/process_management/inference_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/horde_worker_regen/process_management/inference_process.py b/horde_worker_regen/process_management/inference_process.py index 7435bbaf..b69fa296 100644 --- a/horde_worker_regen/process_management/inference_process.py +++ b/horde_worker_regen/process_management/inference_process.py @@ -431,7 +431,7 @@ def progress_callback( except Exception as e: logger.error(f"Failed to release inference semaphore: {type(e).__name__} {e}") - if progress_report.comfyui_progress is not None and progress_report.comfyui_progress.current_step >= 0: + if progress_report.comfyui_progress is not None and progress_report.comfyui_progress.current_step > 0: self.send_heartbeat_message(heartbeat_type=HordeHeartbeatType.INFERENCE_STEP) else: self.send_heartbeat_message(heartbeat_type=HordeHeartbeatType.PIPELINE_STATE_CHANGE) From f94c204dde9a46cbc561337539b3e985f5f8da6d Mon Sep 17 00:00:00 2001 From: tazlin Date: Mon, 25 Mar 2024 17:22:56 -0400 Subject: [PATCH 02/14] fix: corrects heartbeat logic This prevents the heartbeat rate limiting from coming into play. Logic elsewhere relies on heartbeat messages changing being noticed, but the rate limiting would occasionally prevent messages from being sent. 
--- horde_worker_regen/process_management/horde_process.py | 8 ++++++-- horde_worker_regen/process_management/process_manager.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/horde_worker_regen/process_management/horde_process.py b/horde_worker_regen/process_management/horde_process.py index 493a8da6..cd28410f 100644 --- a/horde_worker_regen/process_management/horde_process.py +++ b/horde_worker_regen/process_management/horde_process.py @@ -140,14 +140,17 @@ def send_process_state_change_message( _heartbeat_limit_interval_seconds: float = 1.0 _last_heartbeat_time: float = 0.0 + _last_heartbeat_type: HordeHeartbeatType = HordeHeartbeatType.OTHER def send_heartbeat_message(self, heartbeat_type: HordeHeartbeatType) -> None: """Send a heartbeat message to the main process, indicating that the process is still alive. Note that this will only send a heartbeat message if the last heartbeat was sent more than - `_heartbeat_limit_interval_seconds` ago. + `_heartbeat_limit_interval_seconds` ago or if the heartbeat type has changed. """ - if (time.time() - self._last_heartbeat_time) < self._heartbeat_limit_interval_seconds: # FIXME? 
+ if (heartbeat_type != self._last_heartbeat_type) and ( + time.time() - self._last_heartbeat_time + ) < self._heartbeat_limit_interval_seconds: return message = HordeProcessHeartbeatMessage( @@ -158,6 +161,7 @@ def send_heartbeat_message(self, heartbeat_type: HordeHeartbeatType) -> None: ) self.process_message_queue.put(message) + self._last_heartbeat_type = heartbeat_type self._last_heartbeat_time = time.time() @abstractmethod diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index acbe14c9..cf6e6b24 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -454,7 +454,7 @@ def is_stuck_on_inference(self, process_id: int) -> bool: return False if ( self[process_id].last_heartbeat_type == HordeHeartbeatType.INFERENCE_STEP - and (time.time() - self[process_id].last_heartbeat_timestamp) > 30 + and (time.time() - self[process_id].last_heartbeat_timestamp) > 45 ): return True return False From 3b8e1268a4c3982f2e938fee13519b26a431fc44 Mon Sep 17 00:00:00 2001 From: tazlin Date: Wed, 27 Mar 2024 10:49:53 -0400 Subject: [PATCH 03/14] fix: avoid death spiral on maint mode/total proc. 
recovery --- .../process_management/process_manager.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index cf6e6b24..b2bc878a 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -3003,6 +3003,7 @@ async def api_job_pop(self) -> None: else: logger.error(f"Failed to pop job (API Error): {job_pop_response}") self._job_pop_frequency = self._error_job_pop_frequency + self._last_pop_no_jobs_available = True return except Exception as e: @@ -3601,6 +3602,12 @@ def replace_hung_processes(self) -> bool: """Replaces processes that haven't checked in since `process_timeout` seconds in bridgeData.""" now = time.time() + import threading + + def timed_unset_recently_recovered() -> None: + time.sleep(60) + self._recently_recovered = False + # If every process hasn't done anything for a while or if we haven't submitted a job for a while, # AND the last job pop returned a job, we're in a black hole and we need to exit because none of the ways to # recover worked @@ -3632,12 +3639,6 @@ def replace_hung_processes(self) -> bool: if process_info.process_type == HordeProcessType.INFERENCE: self._replace_inference_process(process_info) - def timed_unset_recently_recovered() -> None: - time.sleep(60) - self._recently_recovered = False - - import threading - threading.Thread(target=timed_unset_recently_recovered).start() return True @@ -3645,12 +3646,16 @@ def timed_unset_recently_recovered() -> None: if self._shutting_down: return False + if self._last_pop_no_jobs_available or self._recently_recovered: + return False + any_replaced = False for process_info in self._process_map.values(): if self._process_map.is_stuck_on_inference(process_info.process_id): logger.error(f"{process_info} seems to be stuck mid inference, replacing it") 
self._replace_inference_process(process_info) any_replaced = True + self._recently_recovered = True else: conditions: list[tuple[float, HordeProcessState, str]] = [ ( @@ -3682,6 +3687,10 @@ def timed_unset_recently_recovered() -> None: for timeout, state, error_message in conditions: if self._check_and_replace_process(process_info, timeout, state, error_message): any_replaced = True + self._recently_recovered = True break + if any_replaced: + threading.Thread(target=timed_unset_recently_recovered).start() + return any_replaced From 389474fb78ea20e5cf0fe058f5482069f103e20b Mon Sep 17 00:00:00 2001 From: db0 Date: Sat, 23 Mar 2024 17:11:54 +0100 Subject: [PATCH 04/14] feat: initial remix support --- .gitignore | 3 + convert_config_to_env.py | 1 - horde-bridge.cmd | 2 +- horde_worker_regen/__init__.py | 2 +- horde_worker_regen/bridge_data/load_config.py | 2 +- horde_worker_regen/download_models.py | 2 +- .../process_management/inference_process.py | 6 +- .../process_management/main_entry_point.py | 1 - .../process_management/process_manager.py | 286 ++++++++++++------ horde_worker_regen/run_worker.py | 2 +- pyproject.toml | 3 + requirements.txt | 2 +- tests/test_bridge_data.py | 2 +- 13 files changed, 215 insertions(+), 99 deletions(-) diff --git a/.gitignore b/.gitignore index c72168e0..a1e4f4fb 100644 --- a/.gitignore +++ b/.gitignore @@ -174,3 +174,6 @@ cython_debug/ bin/* conda/* models/* +clip_blip/* +hf_transformers/* +horde_model_reference/* diff --git a/convert_config_to_env.py b/convert_config_to_env.py index c0712c60..d6c738f4 100644 --- a/convert_config_to_env.py +++ b/convert_config_to_env.py @@ -12,7 +12,6 @@ import argparse from horde_model_reference.model_reference_manager import ModelReferenceManager - from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat diff --git a/horde-bridge.cmd b/horde-bridge.cmd index 6d41a9c8..d774dfe8 100644 --- a/horde-bridge.cmd +++ b/horde-bridge.cmd @@ -4,7 +4,7 @@ cd /d %~dp0 : This first 
call to runtime activates the environment for the rest of the script call runtime python -s -m pip -V -call python -s -m pip install horde_sdk~=0.8.3 horde_model_reference~=0.6.3 hordelib~=2.7.6 -U +call python -s -m pip install horde_sdk~=0.9.0 horde_model_reference~=0.6.3 hordelib~=2.7.4 -U if %ERRORLEVEL% NEQ 0 ( echo "Please run update-runtime.cmd." GOTO END diff --git a/horde_worker_regen/__init__.py b/horde_worker_regen/__init__.py index bc32109d..b4009f49 100644 --- a/horde_worker_regen/__init__.py +++ b/horde_worker_regen/__init__.py @@ -8,4 +8,4 @@ ASSETS_FOLDER_PATH = Path(__file__).parent / "assets" -__version__ = "4.3.9" +__version__ = "5.0.0" diff --git a/horde_worker_regen/bridge_data/load_config.py b/horde_worker_regen/bridge_data/load_config.py index 6da41d6d..56eced63 100644 --- a/horde_worker_regen/bridge_data/load_config.py +++ b/horde_worker_regen/bridge_data/load_config.py @@ -6,13 +6,13 @@ from enum import auto from pathlib import Path -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient from horde_sdk.ai_horde_worker.model_meta import ImageModelLoadResolver from loguru import logger from ruamel.yaml import YAML from strenum import StrEnum +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data import AIWORKER_REGEN_PREFIX from horde_worker_regen.bridge_data.data_model import reGenBridgeData diff --git a/horde_worker_regen/download_models.py b/horde_worker_regen/download_models.py index 0f0cdbc1..a4b5cc2e 100644 --- a/horde_worker_regen/download_models.py +++ b/horde_worker_regen/download_models.py @@ -12,9 +12,9 @@ def download_all_models( if not load_config_from_env_vars: load_env_vars_from_config() - from horde_model_reference.model_reference_manager import ModelReferenceManager from loguru import logger + from horde_model_reference.model_reference_manager import 
ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME diff --git a/horde_worker_regen/process_management/inference_process.py b/horde_worker_regen/process_management/inference_process.py index b69fa296..1b33fb9c 100644 --- a/horde_worker_regen/process_management/inference_process.py +++ b/horde_worker_regen/process_management/inference_process.py @@ -451,9 +451,11 @@ def start_inference(self, job_info: ImageGenerateJobPopResponse) -> list[Resulti self._is_busy = True try: logger.info(f"Starting inference for job(s) {job_info.ids}") + esi_count = len(job_info.extra_source_images) if job_info.extra_source_images is not None else 0 logger.debug( - f"has source_image: {job_info.source_image is not None} " - f"has source_mask: {job_info.source_mask is not None}", + f"has source_image: {job_info.source_image is not None}, " + f"has source_mask: {job_info.source_mask is not None}, " + f"extra_source_images: {esi_count}", ) logger.debug(f"{job_info.payload.model_dump(exclude={'prompt'})}") diff --git a/horde_worker_regen/process_management/main_entry_point.py b/horde_worker_regen/process_management/main_entry_point.py index 73051216..65d3cadc 100644 --- a/horde_worker_regen/process_management/main_entry_point.py +++ b/horde_worker_regen/process_management/main_entry_point.py @@ -1,7 +1,6 @@ from multiprocessing.context import BaseContext from horde_model_reference.model_reference_manager import ModelReferenceManager - from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.process_management.process_manager import HordeWorkerProcessManager diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index b2bc878a..b175f0dd 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ 
b/horde_worker_regen/process_management/process_manager.py @@ -26,13 +26,11 @@ import psutil import yarl from aiohttp import ClientSession -from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY -from horde_model_reference.model_reference_manager import ModelReferenceManager -from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_sdk import RequestErrorResponse from horde_sdk.ai_horde_api import GENERATION_STATE from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncClientSession, AIHordeAPIAsyncSimpleClient from horde_sdk.ai_horde_api.apimodels import ( + ExtraSourceImageEntry, FindUserRequest, FindUserResponse, GenMetadataEntry, @@ -44,9 +42,11 @@ from horde_sdk.ai_horde_api.fields import JobID from loguru import logger from pydantic import BaseModel, ConfigDict, RootModel, ValidationError -from typing_extensions import Any import horde_worker_regen +from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY +from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME, KNOWN_SLOW_MODELS_DIFFICULTIES, VRAM_HEAVY_MODELS @@ -712,16 +712,88 @@ class JobSubmitState(enum.Enum): # TODO: Split into a new file """The job submit faulted for some reason.""" -class PendingSubmitJob(BaseModel): # TODO: Split into a new file +class PendingJob(BaseModel): + """Base class for all PendingJobs async tasks.""" + + state: JobSubmitState = JobSubmitState.PENDING + _max_consecutive_failed_job_submits: int = 10 + _consecutive_failed_job_submits: int = 0 + + @property + def is_finished(self) -> bool: + """Return true if the 
job submit has finished.""" + return self.state != JobSubmitState.PENDING + + @property + def is_faulted(self) -> bool: + """Return true if the job submit has faulted.""" + return self.state == JobSubmitState.FAULTED + + @property + def retry_attempts_string(self) -> str: + """Return a string containing the number of consecutive failed job submits and the maximum allowed.""" + return f"{self._consecutive_failed_job_submits}/{self._max_consecutive_failed_job_submits}" + + def retry(self) -> None: + """Mark the job as needing to be retried. Fault the job if it has been retried too many times.""" + self._consecutive_failed_job_submits += 1 + if self._consecutive_failed_job_submits > self._max_consecutive_failed_job_submits: + self.state = JobSubmitState.FAULTED + + def succeed(self) -> None: + """Mark the job as successfully submitted.""" + self.state = JobSubmitState.SUCCESS + + def fault(self) -> None: + """Mark the job as faulted.""" + self.state = JobSubmitState.FAULTED + + +class PendingSourceDownloadJob(PendingJob): + """Information about a source image to download from the horde.""" + + job_pop_response: ImageGenerateJobPopResponse + field_name: str + download_url: str | None + esi_index: int | None = None + image_b64: str | None = None + fault_metadata: GenMetadataEntry | None = None + + @property + def log_reference(self) -> str: + """Returns a string identifying the source image for logs.""" + log_reference = self.field_name + if self.esi_index is not None: + log_reference = f"{self.field_name}_{self.esi_index}" + return log_reference + + def retry(self, metadata_value: METADATA_VALUE) -> None: + """Marks this task to be retried. 
Adds GenMetadataEntry if retries exceeded.""" + super().retry() + if self.is_faulted: + self.fault_metadata = GenMetadataEntry( + type=METADATA_TYPE[self.field_name], + value=metadata_value, + ref=str(self.esi_index), + ) + + def fault(self, metadata_value: METADATA_VALUE) -> None: + """Faults this task and adds a GenMetadataEntry.""" + self.fault_metadata = GenMetadataEntry( + type=METADATA_TYPE[self.field_name], + value=metadata_value, + ref=str(self.esi_index), + ) + super().fault() + + +class PendingSubmitJob(PendingJob): # TODO: Split into a new file """Information about a job to submit to the horde.""" completed_job_info: HordeJobInfo gen_iter: int - state: JobSubmitState = JobSubmitState.PENDING kudos_reward: int = 0 kudos_per_second: float = 0.0 - _max_consecutive_failed_job_submits: int = 10 - _consecutive_failed_job_submits: int = 0 @property def image_result(self) -> HordeImageResult | None: @@ -742,32 +814,11 @@ def r2_upload(self) -> str: return "" # FIXME: Is this ever None? Or just a bad declaration on sdk? return self.completed_job_info.sdk_api_job_info.r2_uploads[self.gen_iter] - @property - def is_finished(self) -> bool: - """Return true if the job submit has finished.""" - return self.state != JobSubmitState.PENDING - - @property - def is_faulted(self) -> bool: - """Return true if the job submit has faulted.""" - return self.state == JobSubmitState.FAULTED - - @property - def retry_attempts_string(self) -> str: - """Return a string containing the number of consecutive failed job submits and the maximum allowed.""" - return f"{self._consecutive_failed_job_submits}/{self._max_consecutive_failed_job_submits}" - @property def batch_count(self) -> int: """Return the number of jobs in the batch.""" return len(self.completed_job_info.sdk_api_job_info.ids) - def retry(self) -> None: - """Mark the job as needing to be retried. 
Fault the job if it has been retried too many times.""" - self._consecutive_failed_job_submits += 1 - if self._consecutive_failed_job_submits > self._max_consecutive_failed_job_submits: - self.state = JobSubmitState.FAULTED - def succeed(self, kudos_reward: int, kudos_per_second: float) -> None: """Mark the job as successfully submitted. @@ -777,11 +828,7 @@ def succeed(self, kudos_reward: int, kudos_per_second: float) -> None: """ self.kudos_reward = kudos_reward self.kudos_per_second = kudos_per_second - self.state = JobSubmitState.SUCCESS - - def fault(self) -> None: - """Mark the job as faulted.""" - self.state = JobSubmitState.FAULTED + super().succeed() class NextJobAndProcess(BaseModel): @@ -2747,70 +2794,134 @@ def should_wait_for_pending_megapixelsteps(self) -> bool: return self.get_pending_megapixelsteps() > self._max_pending_megapixelsteps + async def download_source_image(self, new_download: PendingSourceDownloadJob) -> PendingSourceDownloadJob: + """Downloads a single source image asynchronously.""" + if new_download.download_url is not None and "https://" not in new_download.download_url: + new_download.fault(METADATA_VALUE.download_failed) + return new_download + logger.debug(f"Starting download of {new_download.log_reference}") + try: + # self.job_faults[job_pop_response.id_].append(new_meta_entry) + # logger.error(f"Failed to download {new_download.log_reference} after {fail_count} attempts") + response = await self._aiohttp_session.get( + new_download.download_url, + timeout=aiohttp.ClientTimeout(total=10), + ) + response.raise_for_status() + + content = await response.content.read() + new_download.image_b64 = base64.b64encode(content).decode("utf-8") + + logger.debug(f"Downloaded {new_download.log_reference} for job {new_download.job_pop_response.id_}") + except Exception as err: + logger.debug(f"{type(err)}: {err}") + logger.warning(f"Failed to download {new_download.log_reference}: {err}") + new_download.retry(METADATA_VALUE.download_failed) + 
return new_download + + try: + if new_download.image_b64 is not None: + image = PIL.Image.open(BytesIO(base64.b64decode(new_download.image_b64))) + image.verify() + except Exception as err: + logger.error(f"Failed to verify {new_download.log_reference}: {err}") + if new_download.job_pop_response.id_ is None: + raise ValueError("job_pop_response.id_ is None") from err + new_download.fault(METADATA_VALUE.parse_failed) + new_download.succeed() + return new_download + async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse) -> ImageGenerateJobPopResponse: # Adding this to stop mypy complaining if job_pop_response.id_ is None: logger.error("Received ImageGenerateJobPopResponse with id_ is None. Please let the devs know!") return job_pop_response - image_fields: list[str] = ["source_image", "source_mask"] - new_response_dict: dict[str, Any] = job_pop_response.model_dump(by_alias=True) + download_tasks: list[Task[PendingSourceDownloadJob]] = [] + finished_download_tasks: list[PendingSourceDownloadJob] = [] - # TODO: Move this into horde_sdk - for field_name in image_fields: - field_value = new_response_dict[field_name] - if field_value is not None and "https://" in field_value: - fail_count = 0 - while True: - try: - if fail_count >= 10: - new_meta_entry = GenMetadataEntry( - type=METADATA_TYPE[field_name], - value=METADATA_VALUE.download_failed, - ) - self.job_faults[job_pop_response.id_].append(new_meta_entry) - logger.error(f"Failed to download {field_name} after {fail_count} attempts") - break - response = await self._aiohttp_session.get( - field_value, - timeout=aiohttp.ClientTimeout(total=10), - ) - response.raise_for_status() - - content = await response.content.read() - - new_response_dict[field_name] = base64.b64encode(content).decode("utf-8") - - logger.debug(f"Downloaded {field_name} for job {job_pop_response.id_}") - break - except Exception as e: - logger.debug(f"{type(e)}: {e}") - logger.warning(f"Failed to download {field_name}: 
{e}") - fail_count += 1 - await asyncio.sleep(0.5) + if job_pop_response.source_image is not None: + new_download = PendingSourceDownloadJob( + job_pop_response=job_pop_response, + field_name="source_image", + download_url=job_pop_response.source_image, + ) + download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) + # Nested because we don't try to download source mask if source image doesn't exist + if job_pop_response.source_mask is not None: + new_download = PendingSourceDownloadJob( + job_pop_response=job_pop_response, + field_name="source_mask", + download_url=job_pop_response.source_mask, + ) + download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) + if job_pop_response.extra_source_images is not None: + for esi_index, esi in enumerate(job_pop_response.extra_source_images): + new_download = PendingSourceDownloadJob( + job_pop_response=job_pop_response, + field_name="extra_source_images", + download_url=esi.image, + esi_index=esi_index, + ) + download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) - for field in image_fields: - try: - field_value = new_response_dict[field] - if field_value is not None: - image = PIL.Image.open(BytesIO(base64.b64decode(field_value))) - image.verify() + while len(download_tasks) > 0: + retry_submits: list[PendingSourceDownloadJob] = [] + results = await asyncio.gather(*download_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + logger.exception(f"Exception in source image download task: {result}") + fault_metadata = GenMetadataEntry( + # When an exception occurs, we don't know what kind of source type was being download + # So we just report that the source image failed. 
+ type=METADATA_TYPE.source_image, + value=METADATA_VALUE.download_failed, + ) + self.job_faults[job_pop_response.id_].append(fault_metadata) + elif isinstance(result, PendingSourceDownloadJob): + if not result.is_finished: + retry_submits.append(result) + else: + finished_download_tasks.append(result) + download_tasks = [] + for retry_submit in retry_submits: + download_tasks.append(asyncio.create_task(self.download_source_image(retry_submit))) - except Exception as e: - logger.error(f"Failed to verify {field}: {e}") - if job_pop_response.id_ is None: - raise ValueError("job_pop_response.id_ is None") from e + updated_souces = { + "source_image": job_pop_response.source_image, + "source_mask": job_pop_response.source_mask, + "extra_source_images": job_pop_response.extra_source_images, + "source_processing": job_pop_response.source_processing, + } - new_meta_entry = GenMetadataEntry( - type=METADATA_TYPE[field], - value=METADATA_VALUE.parse_failed, + for finished_download in finished_download_tasks: + if finished_download.is_faulted and finished_download.fault_metadata is not None: + self.job_faults[job_pop_response.id_].append(finished_download.fault_metadata) + logger.error(f"Failed to {finished_download.fault_metadata.value} on {new_download.log_reference}") + if finished_download.field_name == "source_image": + updated_souces["source_image"] = finished_download.image_b64 + if finished_download.field_name == "source_mask": + updated_souces["source_mask"] = finished_download.image_b64 + if ( + finished_download.field_name == "extra_source_images" + and job_pop_response.extra_source_images is not None + ): + upd_esi = updated_souces["extra_source_images"][finished_download.esi_index].model_copy( + update={ + "image": finished_download.image_b64, + }, ) - self.job_faults[job_pop_response.id_].append(new_meta_entry) - - new_response_dict[field] = None - new_response_dict["source_processing"] = "txt2img" - - return ImageGenerateJobPopResponse(**new_response_dict) + 
updated_souces["extra_source_images"][finished_download.esi_index] = upd_esi + if updated_souces["source_image"] is None: + updated_souces["source_processing"] = "text2img" + if updated_souces["extra_source_images"] is not None: + valid_extra_source_images: list[ExtraSourceImageEntry] = [] + for esi in updated_souces["extra_source_images"]: + if esi.image is not None: + valid_extra_source_images.append(esi) + updated_souces["extra_source_images"] = valid_extra_source_images + + return job_pop_response.model_copy(update=updated_souces) _last_pop_no_jobs_available: bool = False @@ -3618,7 +3729,6 @@ def timed_unset_recently_recovered() -> None: ) or ((now - self._last_job_submitted_time) > self.bridge_data.process_timeout) ) and not (self._last_pop_no_jobs_available or self._recently_recovered): - self._cleanup_jobs() if self.bridge_data.exit_on_unhandled_faults: diff --git a/horde_worker_regen/run_worker.py b/horde_worker_regen/run_worker.py index 68d9c6c9..9515a320 100644 --- a/horde_worker_regen/run_worker.py +++ b/horde_worker_regen/run_worker.py @@ -14,9 +14,9 @@ def main(ctx: BaseContext, load_from_env_vars: bool = False) -> None: """Check for a valid config and start the driver ('main') process for the reGen worker.""" - from horde_model_reference.model_reference_manager import ModelReferenceManager from pydantic import ValidationError + from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME from horde_worker_regen.process_management.main_entry_point import start_working diff --git a/pyproject.toml b/pyproject.toml index 172ffd71..4f06209a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,9 @@ exclude = ''' | build | dist | conda + | clip_blip + | hf_transformers + | horde_model_reference )/ ''' diff --git a/requirements.txt b/requirements.txt index a9992279..fc87abef 100644 
--- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch>=2.1.2 -horde_sdk~=0.8.3 +horde_sdk~=0.9.0 horde_safety~=0.2.3 hordelib~=2.7.6 horde_model_reference~=0.6.3 diff --git a/tests/test_bridge_data.py b/tests/test_bridge_data.py index 8218e6f7..4a24551c 100644 --- a/tests/test_bridge_data.py +++ b/tests/test_bridge_data.py @@ -2,10 +2,10 @@ import pathlib import pytest -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.generic_api.consts import ANON_API_KEY from ruamel.yaml import YAML +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat From dd985592a3486fec9bc0a11018ee4b4d9bb997dd Mon Sep 17 00:00:00 2001 From: tazlin Date: Sun, 24 Mar 2024 22:16:30 -0400 Subject: [PATCH 05/14] feat/refactor: better source image handling Co-Authored-By: Divided by Zer0 --- .pre-commit-config.yaml | 4 +- horde-bridge.cmd | 2 +- horde_worker_regen/consts.py | 2 + .../process_management/process_manager.py | 188 +++++++----------- requirements.txt | 4 +- 5 files changed, 76 insertions(+), 124 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bf79af8d..46f07728 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - horde_safety - torch - ruamel.yaml - - hordelib==2.7.6 - - horde_sdk==0.8.3 + - hordelib==2.8.0 + - horde_sdk==0.9.2 - horde_model_reference==0.6.3 - semver diff --git a/horde-bridge.cmd b/horde-bridge.cmd index d774dfe8..cd6b9a01 100644 --- a/horde-bridge.cmd +++ b/horde-bridge.cmd @@ -4,7 +4,7 @@ cd /d %~dp0 : This first call to runtime activates the environment for the rest of the script call runtime python -s -m pip -V -call python -s -m pip install horde_sdk~=0.9.0 horde_model_reference~=0.6.3 hordelib~=2.7.4 -U +call python -s -m pip install horde_sdk~=0.9.2 
horde_model_reference~=0.6.3 hordelib~=2.8.0 -U if %ERRORLEVEL% NEQ 0 ( echo "Please run update-runtime.cmd." GOTO END diff --git a/horde_worker_regen/consts.py b/horde_worker_regen/consts.py index 613545c0..236614d3 100644 --- a/horde_worker_regen/consts.py +++ b/horde_worker_regen/consts.py @@ -15,3 +15,5 @@ MAX_LORAS = 5 TOTAL_LORA_DOWNLOAD_TIMEOUT = BASE_LORA_DOWNLOAD_TIMEOUT + (EXTRA_LORA_DOWNLOAD_TIMEOUT * MAX_LORAS) + +MAX_SOURCE_IMAGE_RETRIES = 5 diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index b175f0dd..caee1d08 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -18,6 +18,7 @@ from multiprocessing.context import BaseContext from multiprocessing.synchronize import Lock as Lock_MultiProcessing from multiprocessing.synchronize import Semaphore +from typing import override import aiohttp import aiohttp.client_exceptions @@ -49,7 +50,12 @@ from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader -from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME, KNOWN_SLOW_MODELS_DIFFICULTIES, VRAM_HEAVY_MODELS +from horde_worker_regen.consts import ( + BRIDGE_CONFIG_FILENAME, + KNOWN_SLOW_MODELS_DIFFICULTIES, + MAX_SOURCE_IMAGE_RETRIES, + VRAM_HEAVY_MODELS, +) from horde_worker_regen.process_management._aliased_types import ProcessQueue from horde_worker_regen.process_management.horde_process import HordeProcessType from horde_worker_regen.process_management.messages import ( @@ -767,6 +773,7 @@ def log_reference(self) -> str: log_reference = f"{self.field_name}_{self.esi_index}" return log_reference + @override def retry(self, metadata_value: METADATA_VALUE) -> None: """Marks this task to be retried. 
Adds GenMetadataEntry if retries exceeded.""" super().retry() @@ -777,6 +784,7 @@ def retry(self, metadata_value: METADATA_VALUE) -> None: ref=str(self.esi_index), ) + @override def fault(self, metadata_value: METADATA_VALUE) -> None: """Faults this task and adds a GenMetadataEntry.""" self.fault_metadata = GenMetadataEntry( @@ -819,6 +827,7 @@ def batch_count(self) -> int: """Return the number of jobs in the batch.""" return len(self.completed_job_info.sdk_api_job_info.ids) + @override def succeed(self, kudos_reward: int, kudos_per_second: float) -> None: """Mark the job as successfully submitted. @@ -948,7 +957,7 @@ def get_process_total_ram_usage(self) -> int: kudos_generated_this_session: float = 0 session_start_time: float = 0 - _aiohttp_session: aiohttp.ClientSession + _aiohttp_client_session: aiohttp.ClientSession stable_diffusion_reference: StableDiffusion_ModelReference | None horde_client: AIHordeAPIAsyncSimpleClient @@ -2301,7 +2310,7 @@ async def submit_single_generation(self, new_submit: PendingSubmitJob) -> Pendin new_submit.fault() return new_submit try: - async with self._aiohttp_session.put( + async with self._aiohttp_client_session.put( yarl.URL(new_submit.r2_upload, encoded=True), data=image_in_buffer.getvalue(), skip_auto_headers=["content-type"], @@ -2794,134 +2803,75 @@ def should_wait_for_pending_megapixelsteps(self) -> bool: return self.get_pending_megapixelsteps() > self._max_pending_megapixelsteps - async def download_source_image(self, new_download: PendingSourceDownloadJob) -> PendingSourceDownloadJob: - """Downloads a single source image asynchronously.""" - if new_download.download_url is not None and "https://" not in new_download.download_url: - new_download.fault(METADATA_VALUE.download_failed) - return new_download - logger.debug(f"Starting download of {new_download.log_reference}") - try: - # self.job_faults[job_pop_response.id_].append(new_meta_entry) - # logger.error(f"Failed to download {new_download.log_reference} after 
{fail_count} attempts") - response = await self._aiohttp_session.get( - new_download.download_url, - timeout=aiohttp.ClientTimeout(total=10), - ) - response.raise_for_status() - - content = await response.content.read() - new_download.image_b64 = base64.b64encode(content).decode("utf-8") - - logger.debug(f"Downloaded {new_download.log_reference} for job {new_download.job_pop_response.id_}") - except Exception as err: - logger.debug(f"{type(err)}: {err}") - logger.warning(f"Failed to download {new_download.log_reference}: {err}") - new_download.retry(METADATA_VALUE.download_failed) - return new_download - - try: - if new_download.image_b64 is not None: - image = PIL.Image.open(BytesIO(base64.b64decode(new_download.image_b64))) - image.verify() - except Exception as err: - logger.error(f"Failed to verify {new_download.log_reference}: {err}") - if new_download.job_pop_response.id_ is None: - raise ValueError("job_pop_response.id_ is None") from err - new_download.fault(METADATA_VALUE.parse_failed) - new_download.succeed() - return new_download - async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse) -> ImageGenerateJobPopResponse: # Adding this to stop mypy complaining if job_pop_response.id_ is None: logger.error("Received ImageGenerateJobPopResponse with id_ is None. 
Please let the devs know!") return job_pop_response - download_tasks: list[Task[PendingSourceDownloadJob]] = [] - finished_download_tasks: list[PendingSourceDownloadJob] = [] + download_tasks: list[Task] = [] - if job_pop_response.source_image is not None: - new_download = PendingSourceDownloadJob( - job_pop_response=job_pop_response, - field_name="source_image", - download_url=job_pop_response.source_image, - ) - download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) - # Nested because we don't try to download source mask if source image doesn't exist - if job_pop_response.source_mask is not None: - new_download = PendingSourceDownloadJob( - job_pop_response=job_pop_response, - field_name="source_mask", - download_url=job_pop_response.source_mask, - ) - download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) - if job_pop_response.extra_source_images is not None: - for esi_index, esi in enumerate(job_pop_response.extra_source_images): - new_download = PendingSourceDownloadJob( - job_pop_response=job_pop_response, - field_name="extra_source_images", - download_url=esi.image, - esi_index=esi_index, - ) - download_tasks.append(asyncio.create_task(self.download_source_image(new_download))) + source_image_is_url = False + if job_pop_response.source_image is not None and job_pop_response.source_image.startswith("http"): + source_image_is_url = True - while len(download_tasks) > 0: - retry_submits: list[PendingSourceDownloadJob] = [] - results = await asyncio.gather(*download_tasks, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - logger.exception(f"Exception in source image download task: {result}") - fault_metadata = GenMetadataEntry( - # When an exception occurs, we don't know what kind of source type was being download - # So we just report that the source image failed. 
- type=METADATA_TYPE.source_image, - value=METADATA_VALUE.download_failed, - ) - self.job_faults[job_pop_response.id_].append(fault_metadata) - elif isinstance(result, PendingSourceDownloadJob): - if not result.is_finished: - retry_submits.append(result) - else: - finished_download_tasks.append(result) - download_tasks = [] - for retry_submit in retry_submits: - download_tasks.append(asyncio.create_task(self.download_source_image(retry_submit))) + source_mask_is_url = False + if job_pop_response.source_mask is not None and job_pop_response.source_mask.startswith("http"): + source_mask_is_url = True - updated_souces = { - "source_image": job_pop_response.source_image, - "source_mask": job_pop_response.source_mask, - "extra_source_images": job_pop_response.extra_source_images, - "source_processing": job_pop_response.source_processing, - } + any_extra_source_images_are_urls = False + if job_pop_response.extra_source_images is not None: + for extra_source_image in job_pop_response.extra_source_images: + if extra_source_image.image.startswith("http"): + any_extra_source_images_are_urls = True + break + + attempts = 0 + while attempts < MAX_SOURCE_IMAGE_RETRIES: + if ( + source_image_is_url + and job_pop_response.source_image is not None + and job_pop_response.get_downloaded_source_image() is None + ): + download_tasks.append(job_pop_response.async_download_source_image(self._aiohttp_client_session)) + if ( + source_mask_is_url + and job_pop_response.source_mask is not None + and job_pop_response.get_downloaded_source_mask() is None + ): + download_tasks.append(job_pop_response.async_download_source_mask(self._aiohttp_client_session)) - for finished_download in finished_download_tasks: - if finished_download.is_faulted and finished_download.fault_metadata is not None: - self.job_faults[job_pop_response.id_].append(finished_download.fault_metadata) - logger.error(f"Failed to {finished_download.fault_metadata.value} on {new_download.log_reference}") - if 
finished_download.field_name == "source_image": - updated_souces["source_image"] = finished_download.image_b64 - if finished_download.field_name == "source_mask": - updated_souces["source_mask"] = finished_download.image_b64 + download_extra_source_images = job_pop_response.get_downloaded_extra_source_images() if ( - finished_download.field_name == "extra_source_images" + any_extra_source_images_are_urls and job_pop_response.extra_source_images is not None + or ( + download_extra_source_images is not None + and job_pop_response.extra_source_images is not None + and len(download_extra_source_images) != len(job_pop_response.extra_source_images) + ) ): - upd_esi = updated_souces["extra_source_images"][finished_download.esi_index].model_copy( - update={ - "image": finished_download.image_b64, - }, + + download_tasks.append( + asyncio.create_task( + job_pop_response.async_download_extra_source_images( + self._aiohttp_client_session, + max_retries=MAX_SOURCE_IMAGE_RETRIES, + ), + ), ) - updated_souces["extra_source_images"][finished_download.esi_index] = upd_esi - if updated_souces["source_image"] is None: - updated_souces["source_processing"] = "text2img" - if updated_souces["extra_source_images"] is not None: - valid_extra_source_images: list[ExtraSourceImageEntry] = [] - for esi in updated_souces["extra_source_images"]: - if esi.image is not None: - valid_extra_source_images.append(esi) - updated_souces["extra_source_images"] = valid_extra_source_images - - return job_pop_response.model_copy(update=updated_souces) + + gather_results = await asyncio.gather(*download_tasks, return_exceptions=True) + + for result in gather_results: + if isinstance(result, Exception): + logger.error(f"Failed to download source image: {result}") + attempts += 1 + break + else: + break + + return job_pop_response _last_pop_no_jobs_available: bool = False @@ -3259,8 +3209,8 @@ async def _job_submit_loop(self) -> None: async def _api_call_loop(self) -> None: """Run the API call loop for 
popping jobs and doing miscellaneous API calls.""" logger.debug("In _api_call_loop") - self._aiohttp_session = ClientSession(requote_redirect_url=False) - async with self._aiohttp_session as aiohttp_session: + self._aiohttp_client_session = ClientSession(requote_redirect_url=False) + async with self._aiohttp_client_session as aiohttp_session: self.horde_client_session = AIHordeAPIAsyncClientSession(aiohttp_session=aiohttp_session) self.horde_client = AIHordeAPIAsyncSimpleClient( aiohttp_session=None, diff --git a/requirements.txt b/requirements.txt index fc87abef..9d1d8703 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch>=2.1.2 -horde_sdk~=0.9.0 +horde_sdk~=0.9.2 horde_safety~=0.2.3 -hordelib~=2.7.6 +hordelib~=2.8.0 horde_model_reference~=0.6.3 python-dotenv From f518fb50ec5489b4cb6c64dd85997da272a1767a Mon Sep 17 00:00:00 2001 From: tazlin Date: Sun, 24 Mar 2024 22:28:49 -0400 Subject: [PATCH 06/14] fix: remove unneeded class, resolve style/lint issues style: fix --- convert_config_to_env.py | 1 + horde_worker_regen/bridge_data/load_config.py | 2 +- horde_worker_regen/download_models.py | 2 +- .../process_management/main_entry_point.py | 1 + .../process_management/process_manager.py | 53 +++---------------- horde_worker_regen/run_worker.py | 2 +- tests/test_bridge_data.py | 2 +- 7 files changed, 12 insertions(+), 51 deletions(-) diff --git a/convert_config_to_env.py b/convert_config_to_env.py index d6c738f4..c0712c60 100644 --- a/convert_config_to_env.py +++ b/convert_config_to_env.py @@ -12,6 +12,7 @@ import argparse from horde_model_reference.model_reference_manager import ModelReferenceManager + from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat diff --git a/horde_worker_regen/bridge_data/load_config.py b/horde_worker_regen/bridge_data/load_config.py index 56eced63..6da41d6d 100644 --- a/horde_worker_regen/bridge_data/load_config.py +++ b/horde_worker_regen/bridge_data/load_config.py @@ -6,13 +6,13 @@ 
from enum import auto from pathlib import Path +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient from horde_sdk.ai_horde_worker.model_meta import ImageModelLoadResolver from loguru import logger from ruamel.yaml import YAML from strenum import StrEnum -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data import AIWORKER_REGEN_PREFIX from horde_worker_regen.bridge_data.data_model import reGenBridgeData diff --git a/horde_worker_regen/download_models.py b/horde_worker_regen/download_models.py index a4b5cc2e..0f0cdbc1 100644 --- a/horde_worker_regen/download_models.py +++ b/horde_worker_regen/download_models.py @@ -12,9 +12,9 @@ def download_all_models( if not load_config_from_env_vars: load_env_vars_from_config() + from horde_model_reference.model_reference_manager import ModelReferenceManager from loguru import logger - from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME diff --git a/horde_worker_regen/process_management/main_entry_point.py b/horde_worker_regen/process_management/main_entry_point.py index 65d3cadc..73051216 100644 --- a/horde_worker_regen/process_management/main_entry_point.py +++ b/horde_worker_regen/process_management/main_entry_point.py @@ -1,6 +1,7 @@ from multiprocessing.context import BaseContext from horde_model_reference.model_reference_manager import ModelReferenceManager + from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.process_management.process_manager import HordeWorkerProcessManager diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index caee1d08..1834857c 100644 --- 
a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -18,7 +18,6 @@ from multiprocessing.context import BaseContext from multiprocessing.synchronize import Lock as Lock_MultiProcessing from multiprocessing.synchronize import Semaphore -from typing import override import aiohttp import aiohttp.client_exceptions @@ -27,11 +26,13 @@ import psutil import yarl from aiohttp import ClientSession +from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY +from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_sdk import RequestErrorResponse from horde_sdk.ai_horde_api import GENERATION_STATE from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncClientSession, AIHordeAPIAsyncSimpleClient from horde_sdk.ai_horde_api.apimodels import ( - ExtraSourceImageEntry, FindUserRequest, FindUserResponse, GenMetadataEntry, @@ -43,11 +44,9 @@ from horde_sdk.ai_horde_api.fields import JobID from loguru import logger from pydantic import BaseModel, ConfigDict, RootModel, ValidationError +from typing_extensions import override import horde_worker_regen -from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY -from horde_model_reference.model_reference_manager import ModelReferenceManager -from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader from horde_worker_regen.consts import ( @@ -746,7 +745,7 @@ def retry(self) -> None: if self._consecutive_failed_job_submits > self._max_consecutive_failed_job_submits: self.state = JobSubmitState.FAULTED - def succeed(self) -> None: + def succeed(self, *args, **kwargs) -> 
None: # noqa: ANN002, ANN003 """Mark the job as successfully submitted.""" self.state = JobSubmitState.SUCCESS @@ -755,46 +754,6 @@ def fault(self) -> None: self.state = JobSubmitState.FAULTED -class PendingSourceDownloadJob(PendingJob): - """Information about a source image to download from the horde.""" - - job_pop_response: ImageGenerateJobPopResponse - field_name: str - download_url: str | None - esi_index: int | None = None - image_b64: str | None = None - fault_metadata: GenMetadataEntry | None = None - - @property - def log_reference(self) -> str: - """Returns a string identifying the source image for logs.""" - log_reference = self.field_name - if self.esi_index is not None: - log_reference = f"{self.field_name}_{self.esi_index}" - return log_reference - - @override - def retry(self, metadata_value: METADATA_VALUE) -> None: - """Marks this task to be retried. Adds GenMetadataEntry if retries exceeded.""" - super().retry() - if self.is_faulted: - self.fault_metadata = GenMetadataEntry( - type=METADATA_TYPE[self.field_name], - value=metadata_value, - ref=str(self.esi_index), - ) - - @override - def fault(self, metadata_value: METADATA_VALUE) -> None: - """Faults this task and adds a GenMetadataEntry.""" - self.fault_metadata = GenMetadataEntry( - type=METADATA_TYPE[self.field_name], - value=metadata_value, - ref=str(self.esi_index), - ) - super().fault() - - class PendingSubmitJob(PendingJob): # TODO: Split into a new file """Information about a job to submit to the horde.""" @@ -828,7 +787,7 @@ def batch_count(self) -> int: return len(self.completed_job_info.sdk_api_job_info.ids) @override - def succeed(self, kudos_reward: int, kudos_per_second: float) -> None: + def succeed(self, kudos_reward: int = 0, kudos_per_second: float = 0) -> None: """Mark the job as successfully submitted. 
Args: diff --git a/horde_worker_regen/run_worker.py b/horde_worker_regen/run_worker.py index 9515a320..68d9c6c9 100644 --- a/horde_worker_regen/run_worker.py +++ b/horde_worker_regen/run_worker.py @@ -14,9 +14,9 @@ def main(ctx: BaseContext, load_from_env_vars: bool = False) -> None: """Check for a valid config and start the driver ('main') process for the reGen worker.""" + from horde_model_reference.model_reference_manager import ModelReferenceManager from pydantic import ValidationError - from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME from horde_worker_regen.process_management.main_entry_point import start_working diff --git a/tests/test_bridge_data.py b/tests/test_bridge_data.py index 4a24551c..8218e6f7 100644 --- a/tests/test_bridge_data.py +++ b/tests/test_bridge_data.py @@ -2,10 +2,10 @@ import pathlib import pytest +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.generic_api.consts import ANON_API_KEY from ruamel.yaml import YAML -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat From 1d4020bbc36743842d8da0cf13b1c7fc830b6af0 Mon Sep 17 00:00:00 2001 From: tazlin Date: Mon, 25 Mar 2024 07:47:48 -0400 Subject: [PATCH 07/14] fix: re-add support for GenMetadata source image faults --- .../process_management/process_manager.py | 67 ++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 1834857c..8e3fb23e 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ 
b/horde_worker_regen/process_management/process_manager.py @@ -2773,17 +2773,19 @@ async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse source_image_is_url = False if job_pop_response.source_image is not None and job_pop_response.source_image.startswith("http"): source_image_is_url = True + logger.debug(f"Source image for job {job_pop_response.id_} is a URL") source_mask_is_url = False if job_pop_response.source_mask is not None and job_pop_response.source_mask.startswith("http"): source_mask_is_url = True + logger.debug(f"Source mask for job {job_pop_response.id_} is a URL") any_extra_source_images_are_urls = False if job_pop_response.extra_source_images is not None: for extra_source_image in job_pop_response.extra_source_images: if extra_source_image.image.startswith("http"): any_extra_source_images_are_urls = True - break + logger.debug(f"Extra source image for job {job_pop_response.id_} is a URL") attempts = 0 while attempts < MAX_SOURCE_IMAGE_RETRIES: @@ -2830,6 +2832,69 @@ async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse else: break + if attempts >= MAX_SOURCE_IMAGE_RETRIES: + if source_image_is_url and job_pop_response.get_downloaded_source_image() is None: + if self.job_faults.get(job_pop_response.id_) is None: + self.job_faults[job_pop_response.id_] = [] + + logger.error(f"Failed to download source image for job {job_pop_response.id_}") + self.job_faults[job_pop_response.id_].append( + GenMetadataEntry( + type=METADATA_TYPE.source_image, + value=METADATA_VALUE.download_failed, + ref="source_image", + ), + ) + + if source_mask_is_url and job_pop_response.get_downloaded_source_mask() is None: + if self.job_faults.get(job_pop_response.id_) is None: + self.job_faults[job_pop_response.id_] = [] + logger.error(f"Failed to download source mask for job {job_pop_response.id_}") + + self.job_faults[job_pop_response.id_].append( + GenMetadataEntry( + type=METADATA_TYPE.source_mask, + 
value=METADATA_VALUE.download_failed, + ref="source_mask", + ), + ) + downloaded_extra_source_images = job_pop_response.get_downloaded_extra_source_images() + if ( + any_extra_source_images_are_urls + and downloaded_extra_source_images is None + or ( + downloaded_extra_source_images is not None + and job_pop_response.extra_source_images is not None + and len(downloaded_extra_source_images) != len(job_pop_response.extra_source_images) + ) + ): + if self.job_faults.get(job_pop_response.id_) is None: + self.job_faults[job_pop_response.id_] = [] + logger.error(f"Failed to download extra source images for job {job_pop_response.id_}") + + ref = [] + if job_pop_response.extra_source_images is not None and downloaded_extra_source_images is not None: + for predownload_extra_source_image in job_pop_response.extra_source_images: + if predownload_extra_source_image.image.startswith("http"): + if any( + predownload_extra_source_image.original_url == extra_source_image.image + for extra_source_image in downloaded_extra_source_images + ): + continue + + ref.append(str(job_pop_response.extra_source_images.index(predownload_extra_source_image))) + elif job_pop_response.extra_source_images is not None and downloaded_extra_source_images is None: + ref = [str(i) for i in range(len(job_pop_response.extra_source_images))] + + for r in ref: + self.job_faults[job_pop_response.id_].append( + GenMetadataEntry( + type=METADATA_TYPE.extra_source_images, + value=METADATA_VALUE.download_failed, + ref=r, + ), + ) + return job_pop_response _last_pop_no_jobs_available: bool = False From 040831f8e45b459559ba8bfd1b6133e46882bc0e Mon Sep 17 00:00:00 2001 From: db0 Date: Wed, 27 Mar 2024 01:40:09 +0100 Subject: [PATCH 08/14] feat: record source sizes and esi count for training model --- .../process_management/process_manager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/horde_worker_regen/process_management/process_manager.py 
b/horde_worker_regen/process_management/process_manager.py index 8e3fb23e..10cb334d 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -103,6 +103,7 @@ "skipped": ..., "source_image": ..., "source_mask": ..., + "extra_source_images": ..., "r2_upload": ..., "r2_uploads": ..., }, @@ -2584,6 +2585,17 @@ async def api_submit_job(self) -> None: model_dump["sdk_api_job_info"]["payload"]["ti_count"] = len( model_dump["sdk_api_job_info"]["payload"]["tis"], ) + model_dump["sdk_api_job_info"]["extra_source_images_count"] = len( + hji.sdk_api_job_info.extra_source_images) if hji.sdk_api_job_info.extra_source_images else 0 + esi_combined_size = 0 + if hji.sdk_api_job_info.extra_source_images: + for esi in hji.sdk_api_job_info.extra_source_images: + esi_combined_size += len(esi.image) + model_dump["sdk_api_job_info"]["extra_source_images_combined_size"] = esi_combined_size + model_dump["sdk_api_job_info"]["source_image_size"] = len( + hji.sdk_api_job_info.source_image) if hji.sdk_api_job_info.source_image else 0 + model_dump["sdk_api_job_info"]["source_mask_size"] = len( + hji.sdk_api_job_info.source_mask) if hji.sdk_api_job_info.source_mask else 0 if not os.path.exists(file_name_to_use): with open(file_name_to_use, "w") as f: json.dump([model_dump], f, indent=4) From 37cf38278fb4c052fce8315c06ccf9859e9933a0 Mon Sep 17 00:00:00 2001 From: tazlin Date: Wed, 27 Mar 2024 08:36:08 -0400 Subject: [PATCH 09/14] style: fix --- .../process_management/process_manager.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 10cb334d..b49a4341 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -2585,17 +2585,22 @@ async def api_submit_job(self) -> None: 
model_dump["sdk_api_job_info"]["payload"]["ti_count"] = len( model_dump["sdk_api_job_info"]["payload"]["tis"], ) - model_dump["sdk_api_job_info"]["extra_source_images_count"] = len( - hji.sdk_api_job_info.extra_source_images) if hji.sdk_api_job_info.extra_source_images else 0 + model_dump["sdk_api_job_info"]["extra_source_images_count"] = ( + len(hji.sdk_api_job_info.extra_source_images) + if hji.sdk_api_job_info.extra_source_images + else 0 + ) esi_combined_size = 0 if hji.sdk_api_job_info.extra_source_images: for esi in hji.sdk_api_job_info.extra_source_images: esi_combined_size += len(esi.image) model_dump["sdk_api_job_info"]["extra_source_images_combined_size"] = esi_combined_size - model_dump["sdk_api_job_info"]["source_image_size"] = len( - hji.sdk_api_job_info.source_image) if hji.sdk_api_job_info.source_image else 0 - model_dump["sdk_api_job_info"]["source_mask_size"] = len( - hji.sdk_api_job_info.source_mask) if hji.sdk_api_job_info.source_mask else 0 + model_dump["sdk_api_job_info"]["source_image_size"] = ( + len(hji.sdk_api_job_info.source_image) if hji.sdk_api_job_info.source_image else 0 + ) + model_dump["sdk_api_job_info"]["source_mask_size"] = ( + len(hji.sdk_api_job_info.source_mask) if hji.sdk_api_job_info.source_mask else 0 + ) if not os.path.exists(file_name_to_use): with open(file_name_to_use, "w") as f: json.dump([model_dump], f, indent=4) From 7b8a846d2aacdb64581db12338514df8f4617772 Mon Sep 17 00:00:00 2001 From: tazlin Date: Wed, 27 Mar 2024 10:57:49 -0400 Subject: [PATCH 10/14] chore: version bump --- horde_worker_regen/__init__.py | 2 +- horde_worker_regen/_version_meta.json | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/horde_worker_regen/__init__.py b/horde_worker_regen/__init__.py index b4009f49..d9ad0a0a 100644 --- a/horde_worker_regen/__init__.py +++ b/horde_worker_regen/__init__.py @@ -8,4 +8,4 @@ ASSETS_FOLDER_PATH = Path(__file__).parent / "assets" -__version__ = "5.0.0" +__version__ = 
"5.0.1" diff --git a/horde_worker_regen/_version_meta.json b/horde_worker_regen/_version_meta.json index 628f0bac..45176791 100644 --- a/horde_worker_regen/_version_meta.json +++ b/horde_worker_regen/_version_meta.json @@ -1,5 +1,5 @@ { - "recommended_version": "4.3.9", + "recommended_version": "5.0.1", "required_min_version": "4.2.7", "required_min_version_update_date": "2024-03-09", "required_min_version_info": { diff --git a/pyproject.toml b/pyproject.toml index 4f06209a..e0126822 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "horde_worker_regen" -version = "4.3.9" +version = "5.0.1" description = "Allows you to connect to the AI Horde and generate images for users." authors = [ {name = "tazlin", email = "tazlin.on.github@gmail.com"}, From 15a75e09b025c7d6545af87adf18024377b83fa3 Mon Sep 17 00:00:00 2001 From: db0 Date: Thu, 28 Mar 2024 12:26:58 +0100 Subject: [PATCH 11/14] feat: switch to newer python and cuda on conda --- environment.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yaml b/environment.yaml index 4be2a4b0..4e269137 100644 --- a/environment.yaml +++ b/environment.yaml @@ -5,7 +5,7 @@ channels: - defaults # These should only contain the minimal essentials to get the binaries going, everything else is managed in requirements.txt to keep it universal. 
dependencies: - - cudatoolkit==11.8.0 + - nvidia::cuda-toolkit - git - pip - - python==3.10 + - python==3.10.12 From 5fee8322d6517c400764921f80ec00393aad08df Mon Sep 17 00:00:00 2001 From: db0 Date: Thu, 28 Mar 2024 18:53:16 +0100 Subject: [PATCH 12/14] chore: update hordelib version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9d1d8703..9ebaedd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=2.1.2 horde_sdk~=0.9.2 horde_safety~=0.2.3 -hordelib~=2.8.0 +hordelib~=2.8.1 horde_model_reference~=0.6.3 python-dotenv From 4cac407d61cdc29d87ded2e1acec1e1674eb90ae Mon Sep 17 00:00:00 2001 From: db0 Date: Thu, 28 Mar 2024 19:12:42 +0100 Subject: [PATCH 13/14] style: linting --- .pre-commit-config.yaml | 2 +- convert_config_to_env.py | 1 - horde-bridge.cmd | 2 +- horde_worker_regen/bridge_data/load_config.py | 2 +- horde_worker_regen/download_models.py | 2 +- horde_worker_regen/process_management/main_entry_point.py | 1 - horde_worker_regen/process_management/process_manager.py | 6 +++--- horde_worker_regen/run_worker.py | 2 +- tests/test_bridge_data.py | 2 +- 9 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46f07728..4a62c5a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: - horde_safety - torch - ruamel.yaml - - hordelib==2.8.0 + - hordelib==2.8.1 - horde_sdk==0.9.2 - horde_model_reference==0.6.3 - semver diff --git a/convert_config_to_env.py b/convert_config_to_env.py index c0712c60..d6c738f4 100644 --- a/convert_config_to_env.py +++ b/convert_config_to_env.py @@ -12,7 +12,6 @@ import argparse from horde_model_reference.model_reference_manager import ModelReferenceManager - from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat diff --git a/horde-bridge.cmd b/horde-bridge.cmd index cd6b9a01..242514ff 100644 --- 
a/horde-bridge.cmd +++ b/horde-bridge.cmd @@ -4,7 +4,7 @@ cd /d %~dp0 : This first call to runtime activates the environment for the rest of the script call runtime python -s -m pip -V -call python -s -m pip install horde_sdk~=0.9.2 horde_model_reference~=0.6.3 hordelib~=2.8.0 -U +call python -s -m pip install horde_sdk~=0.9.2 horde_model_reference~=0.6.3 hordelib~=2.8.1 -U if %ERRORLEVEL% NEQ 0 ( echo "Please run update-runtime.cmd." GOTO END diff --git a/horde_worker_regen/bridge_data/load_config.py b/horde_worker_regen/bridge_data/load_config.py index 6da41d6d..56eced63 100644 --- a/horde_worker_regen/bridge_data/load_config.py +++ b/horde_worker_regen/bridge_data/load_config.py @@ -6,13 +6,13 @@ from enum import auto from pathlib import Path -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient from horde_sdk.ai_horde_worker.model_meta import ImageModelLoadResolver from loguru import logger from ruamel.yaml import YAML from strenum import StrEnum +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data import AIWORKER_REGEN_PREFIX from horde_worker_regen.bridge_data.data_model import reGenBridgeData diff --git a/horde_worker_regen/download_models.py b/horde_worker_regen/download_models.py index 0f0cdbc1..a4b5cc2e 100644 --- a/horde_worker_regen/download_models.py +++ b/horde_worker_regen/download_models.py @@ -12,9 +12,9 @@ def download_all_models( if not load_config_from_env_vars: load_env_vars_from_config() - from horde_model_reference.model_reference_manager import ModelReferenceManager from loguru import logger + from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME diff --git 
a/horde_worker_regen/process_management/main_entry_point.py b/horde_worker_regen/process_management/main_entry_point.py index 73051216..65d3cadc 100644 --- a/horde_worker_regen/process_management/main_entry_point.py +++ b/horde_worker_regen/process_management/main_entry_point.py @@ -1,7 +1,6 @@ from multiprocessing.context import BaseContext from horde_model_reference.model_reference_manager import ModelReferenceManager - from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.process_management.process_manager import HordeWorkerProcessManager diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index b49a4341..f7f6633d 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -26,9 +26,6 @@ import psutil import yarl from aiohttp import ClientSession -from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY -from horde_model_reference.model_reference_manager import ModelReferenceManager -from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_sdk import RequestErrorResponse from horde_sdk.ai_horde_api import GENERATION_STATE from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncClientSession, AIHordeAPIAsyncSimpleClient @@ -47,6 +44,9 @@ from typing_extensions import override import horde_worker_regen +from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY +from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader from horde_worker_regen.consts import ( diff --git 
a/horde_worker_regen/run_worker.py b/horde_worker_regen/run_worker.py index 68d9c6c9..9515a320 100644 --- a/horde_worker_regen/run_worker.py +++ b/horde_worker_regen/run_worker.py @@ -14,9 +14,9 @@ def main(ctx: BaseContext, load_from_env_vars: bool = False) -> None: """Check for a valid config and start the driver ('main') process for the reGen worker.""" - from horde_model_reference.model_reference_manager import ModelReferenceManager from pydantic import ValidationError + from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME from horde_worker_regen.process_management.main_entry_point import start_working diff --git a/tests/test_bridge_data.py b/tests/test_bridge_data.py index 8218e6f7..4a24551c 100644 --- a/tests/test_bridge_data.py +++ b/tests/test_bridge_data.py @@ -2,10 +2,10 @@ import pathlib import pytest -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.generic_api.consts import ANON_API_KEY from ruamel.yaml import YAML +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat From e92a35584fad9429249521622e656f15c76c1c08 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 29 Mar 2024 13:11:43 -0400 Subject: [PATCH 14/14] style: fix --- convert_config_to_env.py | 1 + horde_worker_regen/bridge_data/load_config.py | 2 +- horde_worker_regen/download_models.py | 2 +- horde_worker_regen/process_management/main_entry_point.py | 1 + horde_worker_regen/process_management/process_manager.py | 6 +++--- horde_worker_regen/run_worker.py | 2 +- tests/test_bridge_data.py | 2 +- 7 files changed, 9 insertions(+), 7 deletions(-) diff --git a/convert_config_to_env.py 
b/convert_config_to_env.py index d6c738f4..c0712c60 100644 --- a/convert_config_to_env.py +++ b/convert_config_to_env.py @@ -12,6 +12,7 @@ import argparse from horde_model_reference.model_reference_manager import ModelReferenceManager + from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat diff --git a/horde_worker_regen/bridge_data/load_config.py b/horde_worker_regen/bridge_data/load_config.py index 56eced63..6da41d6d 100644 --- a/horde_worker_regen/bridge_data/load_config.py +++ b/horde_worker_regen/bridge_data/load_config.py @@ -6,13 +6,13 @@ from enum import auto from pathlib import Path +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIManualClient from horde_sdk.ai_horde_worker.model_meta import ImageModelLoadResolver from loguru import logger from ruamel.yaml import YAML from strenum import StrEnum -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data import AIWORKER_REGEN_PREFIX from horde_worker_regen.bridge_data.data_model import reGenBridgeData diff --git a/horde_worker_regen/download_models.py b/horde_worker_regen/download_models.py index a4b5cc2e..0f0cdbc1 100644 --- a/horde_worker_regen/download_models.py +++ b/horde_worker_regen/download_models.py @@ -12,9 +12,9 @@ def download_all_models( if not load_config_from_env_vars: load_env_vars_from_config() + from horde_model_reference.model_reference_manager import ModelReferenceManager from loguru import logger - from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME diff --git a/horde_worker_regen/process_management/main_entry_point.py b/horde_worker_regen/process_management/main_entry_point.py index 65d3cadc..73051216 100644 --- 
a/horde_worker_regen/process_management/main_entry_point.py +++ b/horde_worker_regen/process_management/main_entry_point.py @@ -1,6 +1,7 @@ from multiprocessing.context import BaseContext from horde_model_reference.model_reference_manager import ModelReferenceManager + from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.process_management.process_manager import HordeWorkerProcessManager diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index f7f6633d..b49a4341 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -26,6 +26,9 @@ import psutil import yarl from aiohttp import ClientSession +from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY +from horde_model_reference.model_reference_manager import ModelReferenceManager +from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_sdk import RequestErrorResponse from horde_sdk.ai_horde_api import GENERATION_STATE from horde_sdk.ai_horde_api.ai_horde_clients import AIHordeAPIAsyncClientSession, AIHordeAPIAsyncSimpleClient @@ -44,9 +47,6 @@ from typing_extensions import override import horde_worker_regen -from horde_model_reference.meta_consts import MODEL_REFERENCE_CATEGORY, STABLE_DIFFUSION_BASELINE_CATEGORY -from horde_model_reference.model_reference_manager import ModelReferenceManager -from horde_model_reference.model_reference_records import StableDiffusion_ModelReference from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader from horde_worker_regen.consts import ( diff --git a/horde_worker_regen/run_worker.py b/horde_worker_regen/run_worker.py index 9515a320..68d9c6c9 100644 --- a/horde_worker_regen/run_worker.py +++ 
b/horde_worker_regen/run_worker.py @@ -14,9 +14,9 @@ def main(ctx: BaseContext, load_from_env_vars: bool = False) -> None: """Check for a valid config and start the driver ('main') process for the reGen worker.""" + from horde_model_reference.model_reference_manager import ModelReferenceManager from pydantic import ValidationError - from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, reGenBridgeData from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME from horde_worker_regen.process_management.main_entry_point import start_working diff --git a/tests/test_bridge_data.py b/tests/test_bridge_data.py index 4a24551c..8218e6f7 100644 --- a/tests/test_bridge_data.py +++ b/tests/test_bridge_data.py @@ -2,10 +2,10 @@ import pathlib import pytest +from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_sdk.generic_api.consts import ANON_API_KEY from ruamel.yaml import YAML -from horde_model_reference.model_reference_manager import ModelReferenceManager from horde_worker_regen.bridge_data.data_model import reGenBridgeData from horde_worker_regen.bridge_data.load_config import BridgeDataLoader, ConfigFormat