Skip to content

Commit

Permalink
fix: img2img not downloading
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Oct 1, 2023
1 parent 723ff10 commit ca0d655
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
61 changes: 56 additions & 5 deletions horde_worker_regen/process_management/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections import deque
from collections.abc import Mapping
from io import BytesIO
from multiprocessing.connection import PipeConnection
from multiprocessing.context import BaseContext
from multiprocessing.synchronize import Lock as Lock_MultiProcessing
from multiprocessing.synchronize import Semaphore
Expand Down Expand Up @@ -58,10 +57,15 @@
)
from horde_worker_regen.process_management.worker_entry_points import start_inference_process, start_safety_process

try:
from multiprocessing.connection import PipeConnection as Connection
except ImportError:
from multiprocessing.connection import Connection # type: ignore


class HordeProcessInfo:
mp_process: multiprocessing.Process
pipe_connection: PipeConnection
pipe_connection: Connection
process_id: int
process_type: HordeProcessKind
last_process_state: HordeProcessState
Expand All @@ -76,7 +80,7 @@ class HordeProcessInfo:
def __init__(
self,
mp_process: multiprocessing.Process,
pipe_connection: PipeConnection,
pipe_connection: Connection,
process_id: int,
process_type: HordeProcessKind,
last_process_state: HordeProcessState,
Expand Down Expand Up @@ -1062,8 +1066,8 @@ async def api_job_pop(self) -> None:
job_pop_request = ImageGenerateJobPopRequest(
apikey=self.bridge_data.api_key,
name=self.bridge_data.dreamer_worker_name,
bridge_agent="AI Horde Worker:23:tazlin reGen testing",
bridge_version=23, # TODO TIs broken
bridge_agent="AI Horde Worker reGen:1:https://github.com/Haidra-Org/",
bridge_version=1, # TODO TIs broken
models=self.bridge_data.image_models_to_load,
nsfw=self.bridge_data.nsfw,
threads=self.max_concurrent_inference_processes,
Expand Down Expand Up @@ -1093,6 +1097,11 @@ async def api_job_pop(self) -> None:

self._job_pop_frequency = self._default_job_pop_frequency

info_string = "No job available. "
if len(self.job_deque) > 0:
info_string += f"Current job deque length: {len(self.job_deque)}. "
info_string += f"(Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})"

if job_pop_response.id_ is None:
logger.info(
f"No job available. (Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})",
Expand All @@ -1107,6 +1116,48 @@ async def api_job_pop(self) -> None:
new_response_dict["payload"]["seed"] = random.randint(0, (2**32) - 1)
job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)

if job_pop_response.source_image is not None and "https://" in job_pop_response.source_image:
# Download and convert the source image to base64
fail_count = 0
while True:
try:
if fail_count >= 10:
logger.error(f"Failed to download source image after {fail_count} attempts")
break
source_image_response = requests.get(job_pop_response.source_image)
source_image_response.raise_for_status()
new_response_dict = job_pop_response.model_dump(by_alias=True)

new_response_dict["source_image"] = base64.b64encode(source_image_response.content).decode("utf-8")
job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)
logger.debug(f"Downloaded source image for job {job_pop_response.id_}")
break
except Exception as e:
logger.error(f"Failed to download source image: {e}")
fail_count += 1
time.sleep(0.5)

if job_pop_response.source_mask is not None and "https://" in job_pop_response.source_mask:
# Download and convert the source image to base64
fail_count = 0
while True:
try:
if fail_count >= 10:
logger.error(f"Failed to download source image after {fail_count} attempts")
break
source_mask_response = requests.get(job_pop_response.source_mask)
source_mask_response.raise_for_status()
new_response_dict = job_pop_response.model_dump(by_alias=True)

new_response_dict["source_mask"] = base64.b64encode(source_mask_response.content).decode("utf-8")
job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)
logger.debug(f"Downloaded source image for job {job_pop_response.id_}")
break
except Exception as e:
logger.error(f"Failed to download source_mask: {e}")
fail_count += 1
time.sleep(0.5)

async with self._job_deque_lock:
self.job_deque.append(job_pop_response)
self._testing_jobs_added += 1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch

horde_sdk
horde_sdk>=7.10.0
horde_model_reference
horde_safety
hordelib
Expand Down
4 changes: 3 additions & 1 deletion run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ def main(ctx: BaseContext) -> None:

imlr = ImageModelLoadResolver(horde_model_reference_manager)

resolved_models = None
if bridge_data.meta_load_instructions is not None:
resolved_models = imlr.resolve_meta_instructions(
list(bridge_data.meta_load_instructions),
AIHordeAPIManualClient(),
)

bridge_data.image_models_to_load = list(resolved_models)
if resolved_models is not None:
bridge_data.image_models_to_load = list(set(bridge_data.image_models_to_load + list(resolved_models)))

start_working(ctx=ctx, bridge_data=bridge_data)

Expand Down

0 comments on commit ca0d655

Please sign in to comment.