diff --git a/.gitignore b/.gitignore
index c72168e0..a1e4f4fb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,6 @@ cython_debug/
 bin/*
 conda/*
 models/*
+clip_blip/*
+hf_transformers/*
+horde_model_reference/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bf79af8d..4a62c5a3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -40,7 +40,7 @@ repos:
           - horde_safety
           - torch
           - ruamel.yaml
-          - hordelib==2.7.6
-          - horde_sdk==0.8.3
+          - hordelib==2.8.1
+          - horde_sdk==0.9.2
           - horde_model_reference==0.6.3
           - semver
diff --git a/environment.yaml b/environment.yaml
index 4be2a4b0..4e269137 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -5,7 +5,7 @@ channels:
   - defaults
 # These should only contain the minimal essentials to get the binaries going, everything else is managed in requirements.txt to keep it universal.
 dependencies:
-  - cudatoolkit==11.8.0
+  - nvidia::cuda-toolkit
   - git
   - pip
-  - python==3.10
+  - python==3.10.12
diff --git a/horde-bridge.cmd b/horde-bridge.cmd
index 6d41a9c8..242514ff 100644
--- a/horde-bridge.cmd
+++ b/horde-bridge.cmd
@@ -4,7 +4,7 @@ cd /d %~dp0
 : This first call to runtime activates the environment for the rest of the script
 call runtime python -s -m pip -V
 
-call python -s -m pip install horde_sdk~=0.8.3 horde_model_reference~=0.6.3 hordelib~=2.7.6 -U
+call python -s -m pip install horde_sdk~=0.9.2 horde_model_reference~=0.6.3 hordelib~=2.8.1 -U
 if %ERRORLEVEL% NEQ 0 (
     echo "Please run update-runtime.cmd."
     GOTO END
diff --git a/horde_worker_regen/__init__.py b/horde_worker_regen/__init__.py
index bc32109d..d9ad0a0a 100644
--- a/horde_worker_regen/__init__.py
+++ b/horde_worker_regen/__init__.py
@@ -8,4 +8,4 @@
 
 ASSETS_FOLDER_PATH = Path(__file__).parent / "assets"
 
-__version__ = "4.3.9"
+__version__ = "5.0.1"
diff --git a/horde_worker_regen/_version_meta.json b/horde_worker_regen/_version_meta.json
index 628f0bac..45176791 100644
--- a/horde_worker_regen/_version_meta.json
+++ b/horde_worker_regen/_version_meta.json
@@ -1,5 +1,5 @@
 {
-    "recommended_version": "4.3.9",
+    "recommended_version": "5.0.1",
     "required_min_version": "4.2.7",
     "required_min_version_update_date": "2024-03-09",
     "required_min_version_info": {
diff --git a/horde_worker_regen/consts.py b/horde_worker_regen/consts.py
index 613545c0..236614d3 100644
--- a/horde_worker_regen/consts.py
+++ b/horde_worker_regen/consts.py
@@ -15,3 +15,5 @@
 MAX_LORAS = 5
 
 TOTAL_LORA_DOWNLOAD_TIMEOUT = BASE_LORA_DOWNLOAD_TIMEOUT + (EXTRA_LORA_DOWNLOAD_TIMEOUT * MAX_LORAS)
+
+MAX_SOURCE_IMAGE_RETRIES = 5
diff --git a/horde_worker_regen/process_management/horde_process.py b/horde_worker_regen/process_management/horde_process.py
index 493a8da6..cd28410f 100644
--- a/horde_worker_regen/process_management/horde_process.py
+++ b/horde_worker_regen/process_management/horde_process.py
@@ -140,14 +140,17 @@ def send_process_state_change_message(
 
     _heartbeat_limit_interval_seconds: float = 1.0
     _last_heartbeat_time: float = 0.0
+    _last_heartbeat_type: HordeHeartbeatType = HordeHeartbeatType.OTHER
 
     def send_heartbeat_message(self, heartbeat_type: HordeHeartbeatType) -> None:
         """Send a heartbeat message to the main process, indicating that the process is still alive.
 
         Note that this will only send a heartbeat message if the last heartbeat was sent more than
-        `_heartbeat_limit_interval_seconds` ago.
+        `_heartbeat_limit_interval_seconds` ago or if the heartbeat type has changed.
         """
-        if (time.time() - self._last_heartbeat_time) < self._heartbeat_limit_interval_seconds:  # FIXME?
+        if (heartbeat_type == self._last_heartbeat_type) and (
+            time.time() - self._last_heartbeat_time
+        ) < self._heartbeat_limit_interval_seconds:
             return
 
         message = HordeProcessHeartbeatMessage(
@@ -158,6 +161,7 @@ def send_heartbeat_message(self, heartbeat_type: HordeHeartbeatType) -> None:
         )
 
         self.process_message_queue.put(message)
+        self._last_heartbeat_type = heartbeat_type
         self._last_heartbeat_time = time.time()
 
     @abstractmethod
diff --git a/horde_worker_regen/process_management/inference_process.py b/horde_worker_regen/process_management/inference_process.py
index 7435bbaf..1b33fb9c 100644
--- a/horde_worker_regen/process_management/inference_process.py
+++ b/horde_worker_regen/process_management/inference_process.py
@@ -431,7 +431,7 @@ def progress_callback(
         except Exception as e:
             logger.error(f"Failed to release inference semaphore: {type(e).__name__} {e}")
 
-        if progress_report.comfyui_progress is not None and progress_report.comfyui_progress.current_step >= 0:
+        if progress_report.comfyui_progress is not None and progress_report.comfyui_progress.current_step > 0:
             self.send_heartbeat_message(heartbeat_type=HordeHeartbeatType.INFERENCE_STEP)
         else:
             self.send_heartbeat_message(heartbeat_type=HordeHeartbeatType.PIPELINE_STATE_CHANGE)
@@ -451,9 +451,11 @@ def start_inference(self, job_info: ImageGenerateJobPopResponse) -> list[Resulti
 
         self._is_busy = True
         try:
            logger.info(f"Starting inference for job(s) {job_info.ids}")
+            esi_count = len(job_info.extra_source_images) if job_info.extra_source_images is not None else 0
             logger.debug(
-                f"has source_image: {job_info.source_image is not None} "
-                f"has source_mask: {job_info.source_mask is not None}",
+                f"has source_image: {job_info.source_image is not None}, "
+                f"has source_mask: {job_info.source_mask is not None}, "
+                f"extra_source_images: {esi_count}",
             )
             logger.debug(f"{job_info.payload.model_dump(exclude={'prompt'})}")
diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index acbe14c9..b49a4341 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -44,12 +44,17 @@
 from horde_sdk.ai_horde_api.fields import JobID
 from loguru import logger
 from pydantic import BaseModel, ConfigDict, RootModel, ValidationError
-from typing_extensions import Any
+from typing_extensions import override
 
 import horde_worker_regen
 from horde_worker_regen.bridge_data.data_model import reGenBridgeData
 from horde_worker_regen.bridge_data.load_config import BridgeDataLoader
-from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME, KNOWN_SLOW_MODELS_DIFFICULTIES, VRAM_HEAVY_MODELS
+from horde_worker_regen.consts import (
+    BRIDGE_CONFIG_FILENAME,
+    KNOWN_SLOW_MODELS_DIFFICULTIES,
+    MAX_SOURCE_IMAGE_RETRIES,
+    VRAM_HEAVY_MODELS,
+)
 from horde_worker_regen.process_management._aliased_types import ProcessQueue
 from horde_worker_regen.process_management.horde_process import HordeProcessType
 from horde_worker_regen.process_management.messages import (
@@ -98,6 +103,7 @@
         "skipped": ...,
         "source_image": ...,
         "source_mask": ...,
+        "extra_source_images": ...,
         "r2_upload": ...,
         "r2_uploads": ...,
     },
@@ -454,7 +460,7 @@ def is_stuck_on_inference(self, process_id: int) -> bool:
             return False
         if (
             self[process_id].last_heartbeat_type == HordeHeartbeatType.INFERENCE_STEP
-            and (time.time() - self[process_id].last_heartbeat_timestamp) > 30
+            and (time.time() - self[process_id].last_heartbeat_timestamp) > 45
         ):
             return True
         return False
@@ -712,16 +718,50 @@ class JobSubmitState(enum.Enum):  # TODO: Split into a new file
     """The job submit faulted for some reason."""
 
 
-class PendingSubmitJob(BaseModel):  # TODO: Split into a new file
+class PendingJob(BaseModel):
+    """Base class for all PendingJobs async tasks."""
+
+    state: JobSubmitState = JobSubmitState.PENDING
+    _max_consecutive_failed_job_submits: int = 10
+    _consecutive_failed_job_submits: int = 0
+
+    @property
+    def is_finished(self) -> bool:
+        """Return true if the job submit has finished."""
+        return self.state != JobSubmitState.PENDING
+
+    @property
+    def is_faulted(self) -> bool:
+        """Return true if the job submit has faulted."""
+        return self.state == JobSubmitState.FAULTED
+
+    @property
+    def retry_attempts_string(self) -> str:
+        """Return a string containing the number of consecutive failed job submits and the maximum allowed."""
+        return f"{self._consecutive_failed_job_submits}/{self._max_consecutive_failed_job_submits}"
+
+    def retry(self) -> None:
+        """Mark the job as needing to be retried. Fault the job if it has been retried too many times."""
+        self._consecutive_failed_job_submits += 1
+        if self._consecutive_failed_job_submits > self._max_consecutive_failed_job_submits:
+            self.state = JobSubmitState.FAULTED
+
+    def succeed(self, *args, **kwargs) -> None:  # noqa: ANN002, ANN003
+        """Mark the job as successfully submitted."""
+        self.state = JobSubmitState.SUCCESS
+
+    def fault(self) -> None:
+        """Mark the job as faulted."""
+        self.state = JobSubmitState.FAULTED
+
+
+class PendingSubmitJob(PendingJob):  # TODO: Split into a new file
     """Information about a job to submit to the horde."""
 
     completed_job_info: HordeJobInfo
     gen_iter: int
-    state: JobSubmitState = JobSubmitState.PENDING
     kudos_reward: int = 0
     kudos_per_second: float = 0.0
-    _max_consecutive_failed_job_submits: int = 10
-    _consecutive_failed_job_submits: int = 0
 
     @property
     def image_result(self) -> HordeImageResult | None:
@@ -742,33 +782,13 @@ def r2_upload(self) -> str:
             return ""  # FIXME: Is this ever None? Or just a bad declaration on sdk?
         return self.completed_job_info.sdk_api_job_info.r2_uploads[self.gen_iter]
 
-    @property
-    def is_finished(self) -> bool:
-        """Return true if the job submit has finished."""
-        return self.state != JobSubmitState.PENDING
-
-    @property
-    def is_faulted(self) -> bool:
-        """Return true if the job submit has faulted."""
-        return self.state == JobSubmitState.FAULTED
-
-    @property
-    def retry_attempts_string(self) -> str:
-        """Return a string containing the number of consecutive failed job submits and the maximum allowed."""
-        return f"{self._consecutive_failed_job_submits}/{self._max_consecutive_failed_job_submits}"
-
     @property
     def batch_count(self) -> int:
        """Return the number of jobs in the batch."""
         return len(self.completed_job_info.sdk_api_job_info.ids)
 
-    def retry(self) -> None:
-        """Mark the job as needing to be retried. Fault the job if it has been retried too many times."""
-        self._consecutive_failed_job_submits += 1
-        if self._consecutive_failed_job_submits > self._max_consecutive_failed_job_submits:
-            self.state = JobSubmitState.FAULTED
-
-    def succeed(self, kudos_reward: int, kudos_per_second: float) -> None:
+    @override
+    def succeed(self, kudos_reward: int = 0, kudos_per_second: float = 0) -> None:
         """Mark the job as successfully submitted.
 
         Args:
@@ -777,11 +797,7 @@ def succeed(self, kudos_reward: int, kudos_per_second: float) -> None:
         """
         self.kudos_reward = kudos_reward
         self.kudos_per_second = kudos_per_second
-        self.state = JobSubmitState.SUCCESS
-
-    def fault(self) -> None:
-        """Mark the job as faulted."""
-        self.state = JobSubmitState.FAULTED
+        super().succeed()
 
 
 class NextJobAndProcess(BaseModel):
@@ -901,7 +917,7 @@ def get_process_total_ram_usage(self) -> int:
     kudos_generated_this_session: float = 0
     session_start_time: float = 0
 
-    _aiohttp_session: aiohttp.ClientSession
+    _aiohttp_client_session: aiohttp.ClientSession
 
     stable_diffusion_reference: StableDiffusion_ModelReference | None
     horde_client: AIHordeAPIAsyncSimpleClient
@@ -2254,7 +2270,7 @@ async def submit_single_generation(self, new_submit: PendingSubmitJob) -> Pendin
             new_submit.fault()
             return new_submit
         try:
-            async with self._aiohttp_session.put(
+            async with self._aiohttp_client_session.put(
                 yarl.URL(new_submit.r2_upload, encoded=True),
                 data=image_in_buffer.getvalue(),
                 skip_auto_headers=["content-type"],
@@ -2569,6 +2585,22 @@ async def api_submit_job(self) -> None:
                    model_dump["sdk_api_job_info"]["payload"]["ti_count"] = len(
                         model_dump["sdk_api_job_info"]["payload"]["tis"],
                     )
+                    model_dump["sdk_api_job_info"]["extra_source_images_count"] = (
+                        len(hji.sdk_api_job_info.extra_source_images)
+                        if hji.sdk_api_job_info.extra_source_images
+                        else 0
+                    )
+                    esi_combined_size = 0
+                    if hji.sdk_api_job_info.extra_source_images:
+                        for esi in hji.sdk_api_job_info.extra_source_images:
+                            esi_combined_size += len(esi.image)
+                    model_dump["sdk_api_job_info"]["extra_source_images_combined_size"] = esi_combined_size
+                    model_dump["sdk_api_job_info"]["source_image_size"] = (
+                        len(hji.sdk_api_job_info.source_image) if hji.sdk_api_job_info.source_image else 0
+                    )
+                    model_dump["sdk_api_job_info"]["source_mask_size"] = (
+                        len(hji.sdk_api_job_info.source_mask) if hji.sdk_api_job_info.source_mask else 0
+                    )
                     if not os.path.exists(file_name_to_use):
                         with open(file_name_to_use, "w") as f:
                             json.dump([model_dump], f, indent=4)
@@ -2752,65 +2784,135 @@ async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse
         if job_pop_response.id_ is None:
             logger.error("Received ImageGenerateJobPopResponse with id_ is None. Please let the devs know!")
             return job_pop_response
 
-        image_fields: list[str] = ["source_image", "source_mask"]
-        new_response_dict: dict[str, Any] = job_pop_response.model_dump(by_alias=True)
+        download_tasks: list[Task] = []
 
-        # TODO: Move this into horde_sdk
-        for field_name in image_fields:
-            field_value = new_response_dict[field_name]
-            if field_value is not None and "https://" in field_value:
-                fail_count = 0
-                while True:
-                    try:
-                        if fail_count >= 10:
-                            new_meta_entry = GenMetadataEntry(
-                                type=METADATA_TYPE[field_name],
-                                value=METADATA_VALUE.download_failed,
-                            )
-                            self.job_faults[job_pop_response.id_].append(new_meta_entry)
-                            logger.error(f"Failed to download {field_name} after {fail_count} attempts")
-                            break
-                        response = await self._aiohttp_session.get(
-                            field_value,
-                            timeout=aiohttp.ClientTimeout(total=10),
-                        )
-                        response.raise_for_status()
+        source_image_is_url = False
+        if job_pop_response.source_image is not None and job_pop_response.source_image.startswith("http"):
+            source_image_is_url = True
+            logger.debug(f"Source image for job {job_pop_response.id_} is a URL")
 
-                        content = await response.content.read()
+        source_mask_is_url = False
+        if job_pop_response.source_mask is not None and job_pop_response.source_mask.startswith("http"):
+            source_mask_is_url = True
+            logger.debug(f"Source mask for job {job_pop_response.id_} is a URL")
 
-                        new_response_dict[field_name] = base64.b64encode(content).decode("utf-8")
+        any_extra_source_images_are_urls = False
+        if job_pop_response.extra_source_images is not None:
+            for extra_source_image in job_pop_response.extra_source_images:
+                if extra_source_image.image.startswith("http"):
+                    any_extra_source_images_are_urls = True
+                    logger.debug(f"Extra source image for job {job_pop_response.id_} is a URL")
 
-                        logger.debug(f"Downloaded {field_name} for job {job_pop_response.id_}")
-                        break
-                    except Exception as e:
-                        logger.debug(f"{type(e)}: {e}")
-                        logger.warning(f"Failed to download {field_name}: {e}")
-                        fail_count += 1
-                        await asyncio.sleep(0.5)
+        attempts = 0
+        while attempts < MAX_SOURCE_IMAGE_RETRIES:
+            if (
+                source_image_is_url
+                and job_pop_response.source_image is not None
+                and job_pop_response.get_downloaded_source_image() is None
+            ):
+                download_tasks.append(job_pop_response.async_download_source_image(self._aiohttp_client_session))
+            if (
+                source_mask_is_url
+                and job_pop_response.source_mask is not None
+                and job_pop_response.get_downloaded_source_mask() is None
+            ):
+                download_tasks.append(job_pop_response.async_download_source_mask(self._aiohttp_client_session))
 
 
-        for field in image_fields:
-            try:
-                field_value = new_response_dict[field]
-                if field_value is not None:
-                    image = PIL.Image.open(BytesIO(base64.b64decode(field_value)))
-                    image.verify()
+            download_extra_source_images = job_pop_response.get_downloaded_extra_source_images()
+            if (
+                any_extra_source_images_are_urls
+                and job_pop_response.extra_source_images is not None
+                or (
+                    download_extra_source_images is not None
+                    and job_pop_response.extra_source_images is not None
+                    and len(download_extra_source_images) != len(job_pop_response.extra_source_images)
+                )
+            ):
 
-            except Exception as e:
-                logger.error(f"Failed to verify {field}: {e}")
-                if job_pop_response.id_ is None:
-                    raise ValueError("job_pop_response.id_ is None") from e
+                download_tasks.append(
+                    asyncio.create_task(
+                        job_pop_response.async_download_extra_source_images(
+                            self._aiohttp_client_session,
+                            max_retries=MAX_SOURCE_IMAGE_RETRIES,
+                        ),
+                    ),
+                )
+
+            gather_results = await asyncio.gather(*download_tasks, return_exceptions=True)
+
+            for result in gather_results:
+                if isinstance(result, Exception):
+                    logger.error(f"Failed to download source image: {result}")
+                    attempts += 1
+                    break
+            else:
+                break
+
+        if attempts >= MAX_SOURCE_IMAGE_RETRIES:
+            if source_image_is_url and job_pop_response.get_downloaded_source_image() is None:
+                if self.job_faults.get(job_pop_response.id_) is None:
+                    self.job_faults[job_pop_response.id_] = []
 
-                new_meta_entry = GenMetadataEntry(
-                    type=METADATA_TYPE[field],
-                    value=METADATA_VALUE.parse_failed,
+                logger.error(f"Failed to download source image for job {job_pop_response.id_}")
+                self.job_faults[job_pop_response.id_].append(
+                    GenMetadataEntry(
+                        type=METADATA_TYPE.source_image,
+                        value=METADATA_VALUE.download_failed,
+                        ref="source_image",
+                    ),
                 )
-                self.job_faults[job_pop_response.id_].append(new_meta_entry)
-                new_response_dict[field] = None
-                new_response_dict["source_processing"] = "txt2img"
+            if source_mask_is_url and job_pop_response.get_downloaded_source_mask() is None:
+                if self.job_faults.get(job_pop_response.id_) is None:
+                    self.job_faults[job_pop_response.id_] = []
+                logger.error(f"Failed to download source mask for job {job_pop_response.id_}")
+
+                self.job_faults[job_pop_response.id_].append(
+                    GenMetadataEntry(
+                        type=METADATA_TYPE.source_mask,
+                        value=METADATA_VALUE.download_failed,
+                        ref="source_mask",
+                    ),
+                )
+            downloaded_extra_source_images = job_pop_response.get_downloaded_extra_source_images()
+            if (
+                any_extra_source_images_are_urls
+                and downloaded_extra_source_images is None
+                or (
+                    downloaded_extra_source_images is not None
+                    and job_pop_response.extra_source_images is not None
+                    and len(downloaded_extra_source_images) != len(job_pop_response.extra_source_images)
+                )
+            ):
+                if self.job_faults.get(job_pop_response.id_) is None:
+                    self.job_faults[job_pop_response.id_] = []
+                logger.error(f"Failed to download extra source images for job {job_pop_response.id_}")
+
+                ref = []
+                if job_pop_response.extra_source_images is not None and downloaded_extra_source_images is not None:
+                    for predownload_extra_source_image in job_pop_response.extra_source_images:
+                        if predownload_extra_source_image.image.startswith("http"):
+                            if any(
+                                predownload_extra_source_image.original_url == extra_source_image.image
+                                for extra_source_image in downloaded_extra_source_images
+                            ):
+                                continue
 
-        return ImageGenerateJobPopResponse(**new_response_dict)
+                            ref.append(str(job_pop_response.extra_source_images.index(predownload_extra_source_image)))
+                elif job_pop_response.extra_source_images is not None and downloaded_extra_source_images is None:
+                    ref = [str(i) for i in range(len(job_pop_response.extra_source_images))]
+
+                for r in ref:
+                    self.job_faults[job_pop_response.id_].append(
+                        GenMetadataEntry(
+                            type=METADATA_TYPE.extra_source_images,
+                            value=METADATA_VALUE.download_failed,
+                            ref=r,
+                        ),
+                    )
+
+        return job_pop_response
 
     _last_pop_no_jobs_available: bool = False
 
@@ -3003,6 +3105,7 @@ async def api_job_pop(self) -> None:
             else:
                 logger.error(f"Failed to pop job (API Error): {job_pop_response}")
                 self._job_pop_frequency = self._error_job_pop_frequency
+                self._last_pop_no_jobs_available = True
                 return
 
         except Exception as e:
@@ -3147,8 +3250,8 @@ async def _job_submit_loop(self) -> None:
     async def _api_call_loop(self) -> None:
         """Run the API call loop for popping jobs and doing miscellaneous API calls."""
         logger.debug("In _api_call_loop")
-        self._aiohttp_session = ClientSession(requote_redirect_url=False)
-        async with self._aiohttp_session as aiohttp_session:
+        self._aiohttp_client_session = ClientSession(requote_redirect_url=False)
+        async with self._aiohttp_client_session as aiohttp_session:
             self.horde_client_session = AIHordeAPIAsyncClientSession(aiohttp_session=aiohttp_session)
             self.horde_client = AIHordeAPIAsyncSimpleClient(
                 aiohttp_session=None,
@@ -3601,6 +3704,12 @@ def replace_hung_processes(self) -> bool:
         """Replaces processes that haven't checked in since `process_timeout` seconds in bridgeData."""
         now = time.time()
 
+        import threading
+
+        def timed_unset_recently_recovered() -> None:
+            time.sleep(60)
+            self._recently_recovered = False
+
         # If every process hasn't done anything for a while or if we haven't submitted a job for a while,
         # AND the last job pop returned a job, we're in a black hole and we need to exit because none of the ways to
         # recover worked
@@ -3611,7 +3720,6 @@ def replace_hung_processes(self) -> bool:
             )
             or ((now - self._last_job_submitted_time) > self.bridge_data.process_timeout)
         ) and not (self._last_pop_no_jobs_available or self._recently_recovered):
-
             self._cleanup_jobs()
 
             if self.bridge_data.exit_on_unhandled_faults:
@@ -3632,12 +3740,6 @@ def replace_hung_processes(self) -> bool:
                 if process_info.process_type == HordeProcessType.INFERENCE:
                     self._replace_inference_process(process_info)
 
-            def timed_unset_recently_recovered() -> None:
-                time.sleep(60)
-                self._recently_recovered = False
-
-            import threading
-
             threading.Thread(target=timed_unset_recently_recovered).start()
 
             return True
@@ -3645,12 +3747,16 @@ def timed_unset_recently_recovered() -> None:
         if self._shutting_down:
             return False
 
+        if self._last_pop_no_jobs_available or self._recently_recovered:
+            return False
+
         any_replaced = False
         for process_info in self._process_map.values():
             if self._process_map.is_stuck_on_inference(process_info.process_id):
                 logger.error(f"{process_info} seems to be stuck mid inference, replacing it")
                 self._replace_inference_process(process_info)
                 any_replaced = True
+                self._recently_recovered = True
             else:
                 conditions: list[tuple[float, HordeProcessState, str]] = [
                     (
@@ -3682,6 +3788,10 @@ def timed_unset_recently_recovered() -> None:
                 for timeout, state, error_message in conditions:
                     if self._check_and_replace_process(process_info, timeout, state, error_message):
                         any_replaced = True
+                        self._recently_recovered = True
                         break
 
+        if any_replaced:
+            threading.Thread(target=timed_unset_recently_recovered).start()
+
         return any_replaced
diff --git a/pyproject.toml b/pyproject.toml
index 172ffd71..e0126822 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "horde_worker_regen"
-version = "4.3.9"
+version = "5.0.1"
 description = "Allows you to connect to the AI Horde and generate images for users."
 authors = [
     {name = "tazlin", email = "tazlin.on.github@gmail.com"},
@@ -101,6 +101,9 @@ exclude = '''
   | build
   | dist
   | conda
+  | clip_blip
+  | hf_transformers
+  | horde_model_reference
 )/
 '''
 
diff --git a/requirements.txt b/requirements.txt
index a9992279..9ebaedd0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 torch>=2.1.2
 
-horde_sdk~=0.8.3
+horde_sdk~=0.9.2
 horde_safety~=0.2.3
-hordelib~=2.7.6
+hordelib~=2.8.1
 horde_model_reference~=0.6.3
 
 python-dotenv