
Add LoRa support to the txt2img and img2img pipelines #119

Merged · 34 commits merged into main on Sep 22, 2024
Conversation

@stronk-dev (Contributor) commented Jul 10, 2024

Adds support to load arbitrary embeddings, modules (like LCM), etc.

Fulfills livepeer/bounties#33

Still requires:
- Testing whether it works
- Gracefully dealing with non-existent requested LoRas
- Design decision: do we want to keep LoRas loaded, or always unload already-loaded weights like we do now?
- Design decision: use the current method of requesting LoRas, or explore other options?
- Design decision: abort inference if one of the loras params is invalid or one of the LoRas fails to load, or continue on like it does now?

LoRas can be loaded by passing a new loras parameter. In the current design this needs to be a string that is parseable as JSON. For example: curl -X POST -H "Content-Type: application/json" localhost:8000/text-to-image -d '{"prompt":"light saber battle in the death star", "loras": "{ \"nerijs/pixel-art-xl\" : 1.2 }"}'
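
For illustration, here's a minimal sketch of how the runner could parse that string into a repo-to-weight mapping (the helper name `parse_loras` and its error handling are assumptions, not the actual code in this PR):

```python
import json


def parse_loras(loras: str) -> dict[str, float]:
    """Parse the `loras` request field, e.g. '{ "nerijs/pixel-art-xl" : 1.2 }',
    into a mapping of LoRa repository id -> adapter weight."""
    if not loras:
        return {}
    try:
        parsed = json.loads(loras)
    except json.JSONDecodeError as e:
        raise ValueError(f"loras must be a JSON object string: {e}") from e
    if not isinstance(parsed, dict):
        raise ValueError("loras must decode to an object mapping repo -> weight")
    return {str(repo): float(weight) for repo, weight in parsed.items()}


# parse_loras('{ "nerijs/pixel-art-xl" : 1.2 }') -> {'nerijs/pixel-art-xl': 1.2}
```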

@rickstaa force-pushed the main branch 3 times, most recently from cd1feb4 to 0d03040 on July 16, 2024 13:10
@stronk-dev (Contributor, Author) commented:

Did some testing:

2024-07-16 13:38:15,602 INFO:     Application startup complete.
2024-07-16 13:38:15,604 INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
100%|██████████| 50/50 [00:06<00:00,  7.93it/s]
2024-07-16 13:38:23,625 INFO:     172.17.0.1:52384 - "POST /text-to-image HTTP/1.1" 200 OK
100%|██████████| 50/50 [00:08<00:00,  6.24it/s]
2024-07-16 13:39:10,578 INFO:     172.17.0.1:35758 - "POST /text-to-image HTTP/1.1" 200 OK
100%|██████████| 50/50 [00:06<00:00,  8.13it/s]
2024-07-16 13:39:22,707 INFO:     172.17.0.1:34084 - "POST /text-to-image HTTP/1.1" 200 OK
100%|██████████| 50/50 [00:08<00:00,  6.23it/s]
2024-07-16 13:39:36,599 INFO:     172.17.0.1:39376 - "POST /text-to-image HTTP/1.1" 200 OK

All images were generated on a 4090 with the SDXL base model using the prompt "light saber battle in the death star". Requests 1 and 3 are without LoRas; requests 2 and 4 requested the nerijs/pixel-art-xl LoRa.

(result images img1–img4 omitted)

Interestingly, inference is a bit slower when using LoRas. I could've used better trigger words in the prompt, but it certainly seems like the LoRa was loaded.

One thing we might want to rethink is the way the loras parameter is passed:

curl -X POST -H "Content-Type: application/json" localhost:8000/text-to-image -d '{"prompt":"light saber battle in the death star", "loras": "{ \"nerijs/pixel-art-xl\" : 1.2 }"}'

@stronk-dev marked this pull request as ready for review on July 16, 2024 13:51
@stronk-dev (Contributor, Author) commented:

If the user requests an invalid LoRa repo, it will print the error: 2024-07-16 14:14:03,062 - app.pipelines.util - WARNING - Unable to load LoRas for adapter 'nerijs/pixel-ar' (RepositoryNotFoundError)

We can make this more verbose by printing the entire exception. The runner will continue with inference, but without using the LoRa.
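
Roughly, the current warn-and-continue behavior looks like this (a sketch assuming a diffusers pipeline; the real helper in app.pipelines.util may differ):

```python
import logging

from huggingface_hub.utils import RepositoryNotFoundError

logger = logging.getLogger(__name__)


def load_loras(pipeline, loras: dict[str, float]) -> None:
    """Best-effort LoRa loading: log a warning and skip adapters that fail to load."""
    adapters, weights = [], []
    for repo, weight in loras.items():
        adapter_name = repo.replace("/", "_")
        try:
            pipeline.load_lora_weights(repo, adapter_name=adapter_name)
        except (RepositoryNotFoundError, OSError, ValueError) as e:
            logger.warning(
                "Unable to load LoRas for adapter '%s' (%s)", repo, type(e).__name__
            )
            continue
        adapters.append(adapter_name)
        weights.append(weight)
    if adapters:
        pipeline.set_adapters(adapters, adapter_weights=weights)
```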

@stronk-dev (Contributor, Author) commented:

(As a sidenote: I think it would be useful if exceptions like that are collected and passed back. Make a best effort to complete inference and inform the user of any issues found during the job. Alternatively, we could also abort inference.)

@eliteprox (Collaborator) commented:

> If the user requests an invalid LoRa repo, it will print the error: 2024-07-16 14:14:03,062 - app.pipelines.util - WARNING - Unable to load LoRas for adapter 'nerijs/pixel-ar' (RepositoryNotFoundError)
>
> We can make this more verbose by printing the entire exception. The runner will continue with inference, but without using the LoRa.

I like your LoRa input validation solution because it handles all incorrect values sufficiently. However, we might want to return these errors to the gateway later to inform the user. I think we should return the error from load_loras and respond with a bad request in the runner in case of an invalid LoRa or weight, so go-livepeer can return it. I like the error messages you have now; I don't think more detail is needed.

We could also hold off on making that change until we develop the go-livepeer side, @rickstaa any thoughts on that approach?

@eliteprox (Collaborator) commented Aug 7, 2024

> Design decision: do we want to keep LoRas loaded, or always unload already-loaded weights like we do now?

VRAM usage looks great; I tested with multiple concurrent requests of different LoRas. I think this is working well as it is. If keeping LoRas loaded would improve inference time, I think we should backlog it as a pipeline improvement.

> Design decision: use the current method of requesting LoRas, or explore other options?

This implementation appears to be working great. I tried a few LoRas from Hugging Face and they are downloaded automatically.

@eliteprox (Collaborator) left a review comment:

Thanks for the PR @stronk-dev! This is a nice addition to the image pipelines. I tested both text-to-image and image-to-image using the ByteDance/SDXL-Lightning model with two different LoRas and multiple images. The pipelines are working great with LoRa support.

See my comments above on the remaining design decisions. I think responding with an informative bad request response when invalid LoRa values are passed will help inform the user on the gateway side. If you can make that change (or we decide to do it during go-livepeer integration) and resolve conflicts then the PR looks good to me.

@eliteprox force-pushed the feature/loras branch 3 times, most recently from 6544ef0 to f6c7e94 on August 10, 2024 00:57
@rickstaa (Member) commented:

Interesting 🤔! Not sure what went on during the OpenAPI spec configuration -> dcaa961. Maybe my generation code is no longer sufficient. Will revert dcaa961 and check on Monday.

@stronk-dev (Contributor, Author) commented:

> See my comments above on the remaining design decisions. I think responding with an informative bad request response when invalid LoRa values are passed will help inform the user on the gateway side. If you can make that change (or we decide to do it during go-livepeer integration) and resolve conflicts then the PR looks good to me.

Just to confirm: we should return a bad request error and abort inference for any of the exceptions in the load_loras function?

@eliteprox (Collaborator) left a review comment:

This change is ready to merge if there are no further optimizations. @stronk-dev, please take a quick look when you can. I've fully tested both text-to-image and image-to-image pipelines with various LoRas.

@eliteprox (Collaborator) commented Aug 23, 2024

> See my comments above on the remaining design decisions. I think responding with an informative bad request response when invalid LoRa values are passed will help inform the user on the gateway side. If you can make that change (or we decide to do it during go-livepeer integration) and resolve conflicts then the PR looks good to me.

> Just to confirm: we should return a bad request error and abort inference for any of the exceptions in the load_loras function?

Correct, and I'd like to send the specific message back to the gateway. I think this would help developers learn the API faster, but in my opinion it's not required for this initial release of LoRa support; open to suggestions.
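
A minimal sketch of what aborting with a bad request could look like in the FastAPI runner (the helper name and the breadth of the except clause are assumptions, not the final implementation):

```python
from fastapi import HTTPException


def apply_loras_or_abort(pipeline, loras: dict[str, float]) -> None:
    """Load the requested LoRas; on any failure, abort inference with a 400
    whose detail message go-livepeer can relay to the end user."""
    try:
        names = []
        for repo, weight in loras.items():
            name = repo.replace("/", "_")
            pipeline.load_lora_weights(repo, adapter_name=name)
            names.append(name)
        pipeline.set_adapters(names, adapter_weights=list(loras.values()))
    except Exception as e:  # e.g. RepositoryNotFoundError or a bad weight value
        raise HTTPException(status_code=400, detail=f"Unable to load LoRas: {e}") from e
```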

@stronk-dev (Contributor, Author) commented:

Thanks for all the tweaks, @eliteprox! LGTM

@eliteprox (Collaborator) commented:

> Thanks for all the tweaks, @eliteprox! LGTM

@rickstaa I've resolved conflicts on runner.gen.go and regenerated the OpenAPI schema. This is ready for merge along with livepeer/go-livepeer#3154.

@eliteprox (Collaborator) commented:

@stronk-dev Just an update on this PR. There is one remaining task: load the 2-step, 4-step, or 8-step checkpoints differently when LoRas are used on non-SDXL models. See https://huggingface.co/ByteDance/SDXL-Lightning#2-step-4-step-8-step-lora

This logic is in t2i and i2i:
https://github.com/livepeer/ai-worker/blob/feature/loras/runner/app/pipelines/text_to_image.py#L86
https://github.com/livepeer/ai-worker/blob/feature/loras/runner/app/pipelines/image_to_image.py#L76
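
For context, the step-distilled LoRA variants on that model card are loaded roughly like this (following the linked card; the checkpoint filename, trailing-timestep scheduler, and guidance_scale=0 come from it, so treat this as a sketch rather than the pipeline's final logic):

```python
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_lora.safetensors"  # pick the 2-, 4- or 8-step variant

pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.load_lora_weights(hf_hub_download(repo, ckpt))
pipe.fuse_lora()
# Lightning checkpoints expect trailing timesteps and no classifier-free guidance.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
image = pipe("A cool cat on the beach", num_inference_steps=4, guidance_scale=0).images[0]
```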

@rickstaa (Member) commented:

> (As a sidenote: I think it would be useful if exceptions like that are collected and passed back. Make a best effort to complete inference and inform the user of any issues found during the job. Alternatively, we could also abort inference.)

@stronk-dev I agree -> #188.

@rickstaa (Member) commented:

> See my comments above on the remaining design decisions. I think responding with an informative bad request response when invalid LoRa values are passed will help inform the user on the gateway side. If you can make that change (or we decide to do it during go-livepeer integration) and resolve conflicts then the PR looks good to me.

@eliteprox good catch. Could you create a separate Linear item for it?

@rickstaa (Member) commented:

@stronk-dev thanks again for your contribution. To ensure optimal network performance and leaner code, I applied several improvements (see the sketch after this list):

  • I migrated the code into its own class.
  • I added some logic to disable loras instead of unloading them when a request without loras comes through.
  • I added loaded lora tracking so we only unload loras when needed.
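
Conceptually, the tracking/disabling logic looks something like this (a sketch assuming a diffusers pipeline; the actual LoraLoader class in this PR may differ in detail):

```python
class LoraLoader:
    """Tracks loaded LoRa adapters so repeated requests avoid redundant (un)loading."""

    def __init__(self, pipeline):
        self.pipeline = pipeline
        self.loaded = {}  # adapter name -> weight currently on the pipeline

    def apply(self, loras: dict[str, float]) -> None:
        if not loras:
            # No LoRas requested: disable adapters instead of unloading them,
            # so a later request can re-enable them cheaply.
            self.pipeline.disable_lora()
            return
        names, weights = [], []
        for repo, weight in loras.items():
            name = repo.replace("/", "_")
            if name not in self.loaded:
                self.pipeline.load_lora_weights(repo, adapter_name=name)
            self.loaded[name] = weight
            names.append(name)
            weights.append(weight)
        self.pipeline.enable_lora()
        self.pipeline.set_adapters(names, adapter_weights=weights)
```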

Output tests

I performed several output tests to ensure the loras were being applied and correctly disabled.

Prompt: pixel, a cute corgi.
Negative prompt: 3d render, realistic.
Base: stabilityai/stable-diffusion-xl-base-1.0.
Loras: { \"latent-consistency/lcm-lora-sdxl\": 1.0, \"nerijs/pixel-art-xl\": 1.2}.
Guidance scale: 1.5.
Inference steps: 50.

Without loras

{
  "model_id": "",
  "loras": "",
  "prompt": "pixel, a cute corgi",
  "height": 576,
  "width": 1024,
  "guidance_scale": 1.5,
  "negative_prompt": "3d render, realistic",
  "safety_check": true,
  "seed": 0,
  "num_inference_steps": 50,
  "num_images_per_prompt": 1
}

image

With loras

{
  "model_id": "",
  "loras": "{ \"latent-consistency/lcm-lora-sdxl\": 1.0, \"nerijs/pixel-art-xl\": 1.2}",
  "prompt": "pixel, a cute corgi",
  "height": 576,
  "width": 1024,
  "guidance_scale": 1.5,
  "negative_prompt": "3d render, realistic",
  "safety_check": true,
  "seed": 0,
  "num_inference_steps": 50,
  "num_images_per_prompt": 1
}

image

Next request without loras

image

Following request with loras again

image

Request without loras and negative prompt

image

Request with better parameters after lora was enabled

{
  "model_id": "",
  "loras": "",
  "prompt": "close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light",
  "height": 576,
  "width": 1024,
  "guidance_scale": 7.5,
  "negative_prompt": "",
  "safety_check": true,
  "seed": 0,
  "num_inference_steps": 50,
  "num_images_per_prompt": 1
}

image

Request with better parameters on restart

{
  "model_id": "",
  "loras": "",
  "prompt": "close-up photo of a beautiful red rose breaking through a cube made of ice , splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light",
  "height": 576,
  "width": 1024,
  "guidance_scale": 7.5,
  "negative_prompt": "",
  "safety_check": true,
  "seed": 0,
  "num_inference_steps": 50,
  "num_images_per_prompt": 1
}

image

This commit applies several performance optimizations that allow the
loras to be kept in memory to be used for similar requests.

@rickstaa (Member) commented Sep 22, 2024

@eliteprox, @stronk-dev I will add one more optimization where we allow up to 4 loras per request but keep up to 8 loras in the buffer, and I will look at memory crashes. After that it is good to be merged.

This commit introduces a buffer to keep LoRas in memory up to a
certain size. This optimization prevents unnecessary reloads when
the orchestrator receives repeated requests, improving network
performance. While PyTorch handles memory cleanup when limits are
reached, we can call `torch.cuda.empty_cache()` after the
`delete_adapters` function call if we encounter frequent
out-of-memory errors.
This commit adds some logic which cleans up loras from GPU memory when
the free memory on the GPU goes below 2 GB. It also increases the max
loras on the GPU to 12.
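
Sketching the buffer and cleanup behavior these commits describe (the 2 GB threshold, the max of 12 buffered loras, delete_adapters, and torch.cuda.empty_cache come from the messages above; everything else is an illustrative assumption):

```python
import torch

MAX_LORAS_ON_GPU = 12               # buffered adapters before we start evicting
MIN_FREE_VRAM_BYTES = 2 * 1024**3   # clean up when free VRAM drops below 2 GB


def evict_loras_if_needed(pipeline, loaded: dict[str, float], active: set[str]) -> None:
    """Delete inactive adapters when the buffer is full or GPU memory runs low."""
    free, _total = torch.cuda.mem_get_info()
    if len(loaded) <= MAX_LORAS_ON_GPU and free >= MIN_FREE_VRAM_BYTES:
        return
    inactive = [name for name in loaded if name not in active]
    if inactive:
        pipeline.delete_adapters(inactive)
        for name in inactive:
            loaded.pop(name, None)
        # Release cached blocks if we keep running into out-of-memory errors.
        torch.cuda.empty_cache()
```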

@rickstaa (Member) commented:

@eliteprox and @stronk-dev I have now also added some logic to clean up memory when the free memory on the GPU goes below 2 GB.

This commit cleans up some unused code.
This commit applies the black formatter to the lora related code.
This commit moves the LoraLoadingError closer to the LoraLoader.

@rickstaa (Member) commented:

Another Check

Base

curl -X POST "http://0.0.0.0:8935/text-to-image" \
    -H "Content-Type: application/json" \
    -d '{
        "model_id":"stabilityai/stable-diffusion-xl-base-1.0",
        "loras": "{ \"latent-consistency/lcm-lora-sdxl\": 0.0, \"KappaNeuro/jim-mahfood-style\": 1.0}",
        "prompt":"A cool cat on the beach",
        "width": 1024,
        "height": 1024,
        "seed": 818566848
    }'

818566848

Loras

curl -X POST "http://0.0.0.0:8935/text-to-image" \
    -H "Content-Type: application/json" \
    -d '{
        "model_id":"stabilityai/stable-diffusion-xl-base-1.0",
        "loras": "{ \"alvdansen/dimension-w-sd15\": 1.0}",
        "prompt":"A cool cat on the beach",
        "width": 1024,
        "height": 1024,
        "seed": 818566848
    }'

35fe8924

Back to base

e212c406

@eliteprox (Collaborator) left a review comment:

Tested loading/unloading LoRa, max of 4 LoRas and validation of invalid LoRa values.

LGTM!

@rickstaa merged commit bcd929d into main on Sep 22, 2024 (2 of 3 checks passed)
@rickstaa deleted the feature/loras branch on September 22, 2024 21:09