Skip to content

Commit

Permalink
lora pipeline fixes (#150)
Browse files Browse the repository at this point in the history
* fix go-binding, fix python error when invalid lora provided

* fix loading/unloading of lora weights
  • Loading branch information
eliteprox authored Aug 22, 2024
1 parent f6c7e94 commit e17dff2
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 41 deletions.
11 changes: 5 additions & 6 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def list(cls):

class TextToImagePipeline(Pipeline):
def __init__(self, model_id: str):
self.loaded_loras = ""
self.loaded_loras = []
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

Expand Down Expand Up @@ -204,15 +204,14 @@ def __call__(
]

# Dynamically (un)load LoRas. "loras" defaults to None when not passed, so the key should always be present in kwargs
if self.loaded_loras != kwargs["loras"]:
if kwargs["loras"] is None:
# Unload previously loaded LoRas
# NOTE: we might want to keep LoRas loaded and only reset their weights
# TODO: run tests with VRAM usage. We should be able to keep the last x LoRas loaded without issues
if self.loaded_loras != "":
self.ldm.unload_lora_weights()
self.ldm.unload_lora_weights()
else:
# Remember requested LoRas and their weights
self.loaded_loras = kwargs["loras"]
load_loras(self.ldm, self.loaded_loras)
self.loaded_loras = load_loras(self.ldm, kwargs["loras"], self.loaded_loras)
# Do not pass the lora param to the model when running inference
del kwargs["loras"]

Expand Down
29 changes: 18 additions & 11 deletions runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def check_nsfw_images(
)
return images, has_nsfw_concept

def load_loras(pipeline: any, requested_loras: str):
def load_loras(pipeline: any, requested_lora: str, loaded_loras: list) -> list:
"""Loads LoRas and sets their weights into the given pipeline.
Args:
Expand All @@ -186,20 +186,26 @@ def load_loras(pipeline: any, requested_loras: str):
if requested_loras == "" or requested_loras == None:
return;
# Parse LoRas param as JSON to extract key-value pairs
# Build a list of adapter names and their requested strength
adapters = []
strengths = []
if len(loaded_loras) > 0:
pipeline.unload_lora_weights()

try:
loras = json.loads(requested_loras)
lora = json.loads(requested_lora)
except Exception as e:
logger.warning(
"Unable to parse '" + requested_loras + "' as JSON. Continuing inference without loading LoRas"
f"Unable to parse '{requested_lora}' as JSON. Continuing inference without loading this LoRa"
)
return
# Build a list of adapter names and their requested strength
adapters = []
strengths = []
for adapter, val in loras.items():

for adapter, val in lora.items():
if adapter in loaded_loras:
pipeline.unload_lora_weights()

# Sanity check: strength should be a number with a minimum value of 0.0
try:
strength = int(val)
strength = float(val)
except ValueError:
logger.warning(
"Skipping requested LoRa " + adapter + ", as it's requested strength (" + val + ") is not a number"
Expand All @@ -210,7 +216,6 @@ def load_loras(pipeline: any, requested_loras: str):
"Clipping strength of LoRa " + adapter + " to 0.0, as it's requested strength (" + val + ") is negative"
)
strength = 0.0
# Load in LoRa weights if its repository exists on HuggingFace
try:
# TODO: If we decide to keep LoRas loaded (and only set their weight to 0), make sure that reloading them causes no performance hit or other issues
pipeline.load_lora_weights(adapter, adapter_name=adapter)
Expand All @@ -223,4 +228,6 @@ def load_loras(pipeline: any, requested_loras: str):
adapters.append(adapter)
strengths.append(strength)
# Set weights for all loaded adapters
pipeline.set_adapters(adapters, strengths)
if len(adapters) > 0:
pipeline.set_adapters(adapters, strengths)
return adapters
2 changes: 1 addition & 1 deletion runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TextToImageParams(BaseModel):
str,
Field(default="", description=""),
]
loras: Annotated[str, Field(loras="")]
loras: Annotated[str, Field(loras="")]=None
prompt: Annotated[str, Field(description="")]
height: Annotated[int, Field(default=576, description="")]
width: Annotated[int, Field(default=1024, description="")]
Expand Down
2 changes: 1 addition & 1 deletion runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@
"loras": {
"type": "string",
"title": "Loras",
"default": ""
"loras": ""
},
"prompt": {
"type": "string",
Expand Down
48 changes: 26 additions & 22 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e17dff2

Please sign in to comment.