fix LoRA training script
catwell committed Sep 1, 2023
1 parent 9f6733d commit 04fe0c5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions scripts/training/finetune-ldm-lora.py
@@ -3,7 +3,7 @@
 from pydantic import BaseModel
 from loguru import logger
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS, lora_targets
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -79,9 +79,10 @@ def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
 
         for model_name in MODELS:
             model = getattr(trainer, model_name)
+            model_targets: list[LoraTarget] = getattr(lora_config, f"{model_name}_targets")
             adapter = LoraAdapter[type(model)](
                 model,
-                sub_targets=getattr(lora_config, f"{model_name}_targets"),
+                sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
                 rank=lora_config.rank,
             )
             for sub_adapter, _ in adapter.sub_adapters:
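
For reference, the fix stops passing the configured LoraTarget values straight to LoraAdapter and instead expands each one into concrete sub-targets via lora_targets(model, target), flattening the results with a nested comprehension. Below is a minimal, self-contained sketch of that flattening pattern; the expand function and the target names are hypothetical stand-ins, not the real refiners API.

# Hypothetical stand-ins: `expand` plays the role of lora_targets(model, target),
# and `model_targets` mimics a config value such as lora_config.unet_targets.
model_targets = ["CrossAttentionBlock", "SelfAttention"]

def expand(target: str) -> list[str]:
    # One configured target typically maps to several concrete sub-modules.
    return [f"{target}.to_q", f"{target}.to_k", f"{target}.to_v"]

# Same nested-comprehension flattening as the fixed sub_targets line:
sub_targets = [x for target in model_targets for x in expand(target)]
print(sub_targets)
# ['CrossAttentionBlock.to_q', 'CrossAttentionBlock.to_k', 'CrossAttentionBlock.to_v',
#  'SelfAttention.to_q', 'SelfAttention.to_k', 'SelfAttention.to_v']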
