
Add Lora Dynamic switching for inference #71

Open
Jeevi10 opened this issue Aug 14, 2024 · 4 comments

Comments


Jeevi10 commented Aug 14, 2024

Add dynamic LoRA (Low-Rank Adaptation) switching functionality, allowing users to change LoRA adapters on the fly during inference without reloading the entire model.
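For illustration, a minimal peft-based sketch of the idea (the model ID and adapter paths are placeholders, and this is not an existing WhisperS2T API):

from transformers import AutoModelForSpeechSeq2Seq
from peft import PeftModel

# Load the base model once
base_model = AutoModelForSpeechSeq2Seq.from_pretrained("distil-whisper/distil-large-v3")

# Attach two separately trained LoRA adapters (paths are placeholders)
model = PeftModel.from_pretrained(base_model, "path/to/adapter_a", adapter_name="adapter_a")
model.load_adapter("path/to/adapter_b", adapter_name="adapter_b")

# Switch the active adapter per request, without reloading the base model
model.set_adapter("adapter_b")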

@StephennFernandes

@Jeevi10 Hey, can you link some resources on dynamic LoRA specifically for Whisper, mainly how this type of inference works and how to use LoRA to fine-tune Whisper?


Jeevi10 commented Aug 26, 2024

@StephennFernandes Thank you for your reply.

Resources for dynamic LoRA:

https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/enc_dec#run-bart-with-lora
https://github.com/cccntu/minLoRA/tree/main
https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#run-llama-with-several-lora-checkpoints
https://github.com/S-LoRA/S-LoRA

I have listed some example repos where I got the idea from. Unfortunately, I don't see any implementation specifically for Whisper.

Just to give you an idea, I created a running example using Hugging Face transformers and peft:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
from peft import PeftModel
import torch_tensorrt

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "distil-whisper/distil-large-v3"

# Load the base Whisper model and processor
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)
base_model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

# Attach two LoRA adapters to the same base model
peft_model_id = "path to checkpoint adapter 1"
peft_model_id_2 = "path to checkpoint adapter2"
model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name='adapter 1', device_map="auto")
model.load_adapter(peft_model_id_2, adapter_name='adapter 2')

# Enable static cache and compile the forward pass
model.generation_config.cache_implementation = "static"
# model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model = torch.compile(model, backend="torch_tensorrt", dynamic=False)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch.float16,
    device=f"cuda:{0}",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def iterate_data(dataset):
    for i, item in enumerate(dataset):
        yield item["audio"]

# Set the batch size according to your device
BATCH_SIZE = 16

predictions = []

# Run streamed inference with adapter 1
for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)

# Switch adapters on the fly, without reloading the base model
pipe.model.set_adapter('adapter 2')

# Run streamed inference with adapter 2
for out in pipe(iterate_data(dataset), batch_size=BATCH_SIZE):
    predictions.append(out["text"])

print(predictions)

Whisper fine-tuning with LoRA:

https://github.com/Vaibhavs10/fast-whisper-finetuning
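For context, a minimal sketch of what the LoRA fine-tuning setup looks like with peft; the model ID, rank, and target modules below are assumptions for illustration, not necessarily what that repo uses:

from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model

# Load the base Whisper checkpoint to be fine-tuned (placeholder model ID)
base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

# Hypothetical LoRA hyperparameters; tune r, alpha, and target_modules for your setup
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)

# Wrap the model so only the low-rank adapter weights are trainable
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()

# After training (e.g. with transformers' Seq2SeqTrainer), save just the adapter:
# model.save_pretrained("path/to/adapter_checkpoint")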

@StephennFernandes

@Jeevi10 thanks for the heads up.

I'll try to write an update for WhisperS2T so it can use dynamic adapters.


Jeevi10 commented Aug 26, 2024

@StephennFernandes I am looking forward to it.
