Skip to content

Commit

Permalink
Merge pull request #71 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix: inject negative embeddings correctly
  • Loading branch information
tazlin authored Sep 6, 2023
2 parents 81f0885 + 87e7c57 commit 2791575
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 39 deletions.
27 changes: 18 additions & 9 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def _apply_aihorde_compatibility_hacks(self, payload):
return payload

def _final_pipeline_adjustments(self, payload, pipeline_data):

payload = deepcopy(payload)

# Process dynamic prompts
Expand Down Expand Up @@ -308,23 +307,31 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
continue
ti_inject = ti.get("inject_ti")
ti_strength = ti.get("strength", 1.0)
try:
ti_strength = float(ti_strength)
except (TypeError, ValueError):
ti_strength = 1.0
if type(ti_strength) not in [float, int]:
ti_strength = 1.0
ti_id = SharedModelManager.manager.ti.get_ti_id(str(ti["name"]))
if ti_inject == "prompt":
payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}'
elif ti_inject == "negprompt":
if "###" not in payload["prompt"]:
payload["prompt"] += "###"
payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})'
SharedModelManager.manager.ti.touch_ti(ti_name)
# create negative prompt if empty
if "negative_prompt" not in payload:
payload["negative_prompt"] = ""

had_leading_comma = payload["negative_prompt"].startswith(",")

payload["negative_prompt"] = f'{payload["negative_prompt"]},(embedding:{ti_id}:{ti_strength})'
if not had_leading_comma:
payload["negative_prompt"] = payload["negative_prompt"].lstrip(",")
# Setup controlnet if required
# For LORAs we completely build the LORA section of the pipeline dynamically, as we have
# to handle n LORA models which form chained nodes in the pipeline.
# Note that we build this between several nodes, the model_loader, clip_skip and the sampler,
# plus the upscale sampler (used in hires fix) if there is one
if payload.get("loras") and SharedModelManager.manager.lora:

# Remove any requested LORAs that we don't have
valid_loras = []
for lora in payload.get("loras"):
Expand Down Expand Up @@ -366,7 +373,6 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
valid_loras.append(lora)
payload["loras"] = valid_loras
for lora_index, lora in enumerate(payload.get("loras")):

# Inject a lora node (first lora)
if lora_index == 0:
pipeline_data[f"lora_{lora_index}"] = {
Expand Down Expand Up @@ -395,7 +401,6 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
}

for lora_index, lora in enumerate(payload.get("loras")):

# The first LORA always connects to the model loader
if lora_index == 0:
self.generator.reconnect_input(pipeline_data, "lora_0.model", "model_loader")
Expand Down Expand Up @@ -545,7 +550,7 @@ def unlock_models(self, models):
self.generator.unlock_models(models)
logger.debug(f"Unlocked models {','.join(models)}")

def basic_inference(self, payload, rawpng=False):
def _get_validated_payload_and_pipeline_data(self, payload) -> tuple[dict, dict]:
# AIHorde hacks to payload
payload = self._apply_aihorde_compatibility_hacks(payload)
# Check payload types/values and normalise it's format
Expand All @@ -557,6 +562,10 @@ def basic_inference(self, payload, rawpng=False):
# Final adjustments to the pipeline
pipeline_data = self.generator.get_pipeline_data(pipeline)
payload = self._final_pipeline_adjustments(payload, pipeline_data)
return payload, pipeline_data

def basic_inference(self, payload, rawpng=False):
payload, pipeline_data = self._get_validated_payload_and_pipeline_data(payload)
models: list[str] = []
# Run the pipeline
try:
Expand Down
Binary file added images_expected/ti_bad_inject.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images_expected/ti_basic.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images_expected/ti_inject.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
87 changes: 57 additions & 30 deletions tests/test_horde_ti.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
# test_horde_ti.py
import os
from datetime import datetime, timedelta
from pathlib import Path

import pytest
from PIL import Image

from hordelib.horde import HordeLib
from hordelib.shared_model_manager import SharedModelManager

from .testing_shared_functions import check_single_lora_image_similarity


class TestHordeTI:
def test_basic_ti(
@pytest.fixture(scope="class")
def basic_ti_payload_data(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
assert shared_model_manager.manager.ti

data = {
) -> dict:
return {
"sampler_name": "k_euler",
"cfg_scale": 8.0,
"denoising_strength": 1.0,
Expand All @@ -33,9 +30,9 @@ def test_basic_ti(
"control_type": None,
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"atmospheric lighting, embedding:7523###(embedding:7808:0.5), embedding:64870",
"prompt": "(embedding:7523:1.0),Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting###(embedding:7808:0.5),(embedding:64870:1.0)",
"tis": [
{"name": 7523},
{"name": 7808},
Expand All @@ -46,7 +43,15 @@ def test_basic_ti(
"model": stable_diffusion_model_name_for_testing,
}

pil_image = hordelib_instance.basic_inference(data)
def test_basic_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
basic_ti_payload_data,
):
assert shared_model_manager.manager.ti

pil_image = hordelib_instance.basic_inference(basic_ti_payload_data)
assert pil_image is not None
assert (
Path(os.path.join(shared_model_manager.manager.ti.modelFolderPath, "64870.safetensors")).exists() is True
Expand All @@ -57,9 +62,9 @@ def test_basic_ti(

def test_inject_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
basic_ti_payload_data: dict,
):
data = {
"sampler_name": "k_euler",
Expand All @@ -76,32 +81,44 @@ def test_inject_ti(
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting",
"tis": [
{"name": 7523, "inject_ti": "prompt", "strength": 0.5},
{"name": 7523, "inject_ti": "prompt", "strength": 1.0},
{"name": 7808, "inject_ti": "negprompt", "strength": 0.5},
{"name": 64870, "inject_ti": "negprompt", "strength": 0.5},
{"name": 64870, "inject_ti": "negprompt", "strength": 1.0},
],
"ddim_steps": 20,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}

payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(data)

basic_payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(
basic_ti_payload_data,
)

assert payload["prompt.text"] == basic_payload["prompt.text"]
assert payload["negative_prompt.text"] == basic_payload["negative_prompt.text"]

assert "(embedding:7523:1.0)" in payload["prompt.text"]
assert "(embedding:7808:0.5)" in payload["negative_prompt.text"]
assert "(embedding:64870:1.0)" in payload["negative_prompt.text"]

pil_image = hordelib_instance.basic_inference(data)
assert pil_image is not None

img_filename = "ti_inject.png"
pil_image.save(f"images/{img_filename}", quality=100)

# assert check_single_lora_image_similarity(
# f"images_expected/{img_filename}",
# pil_image,
# )
assert check_single_lora_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)

def test_bad_inject_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
Expand All @@ -120,22 +137,32 @@ def test_bad_inject_ti(
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting",
"tis": [
{"name": 7523, "inject_ti": "prompt", "strength": "0.5"},
{"name": 7808, "inject_ti": "negprompt", "strength": None},
{"name": 64870, "inject_ti": "YOLO", "strength": "YOLO"},
{"name": 7523, "inject_ti": "prompt", "strength": None},
{"name": 7808, "inject_ti": "negprompt", "strength": "0.5"},
{"name": 64870, "inject_ti": "negprompt", "strength": "1.0"},
{"name": 4629, "inject_ti": "YOLO", "strength": "YOLO"},
],
"ddim_steps": 20,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}

payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(data)

assert "(embedding:7523:1.0)" in payload["prompt.text"]
assert "(embedding:7808:0.5)" in payload["negative_prompt.text"]
assert "(embedding:64870:1.0)" in payload["negative_prompt.text"]

pil_image = hordelib_instance.basic_inference(data)
assert pil_image is not None

# assert check_single_lora_image_similarity(
# f"images_expected/{img_filename}",
# pil_image,
# )
img_filename = "ti_bad_inject.png"
pil_image.save(f"images/{img_filename}", quality=100)

assert check_single_lora_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)

0 comments on commit 2791575

Please sign in to comment.