Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove HF Dataset from Base Config #160

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/finetune-textual-inversion.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dropout_probability = 0
use_gyro_dropout = false

[dataset]
hf_repo = "acme/cat-toy"
hf_repo = "acme/images"
revision = "main"
horizontal_flip = true
random_crop = true
Expand Down
2 changes: 2 additions & 0 deletions scripts/training/finetune-ldm-lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import MODELS, LoraAdapter, LoraTarget, lora_targets
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
Expand Down Expand Up @@ -50,6 +51,7 @@ def process_caption(self, caption: str) -> str:


class LoraLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
lora: LoraConfig

Expand Down
2 changes: 2 additions & 0 deletions scripts/training/finetune-ldm-textual-inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder, TokenEncoder
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.training_utils.callback import Callback
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from refiners.training_utils.latent_diffusion import (
FinetuneLatentDiffusionConfig,
LatentDiffusionConfig,
Expand Down Expand Up @@ -112,6 +113,7 @@ def apply_textual_inversion_to_target(self, text_encoder: CLIPTextEncoder) -> No


class TextualInversionLatentDiffusionConfig(FinetuneLatentDiffusionConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
textual_inversion: TextualInversionConfig

Expand Down
12 changes: 0 additions & 12 deletions src/refiners/training_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,6 @@ class WandbConfig(BaseModel):
notes: str | None = None


class HuggingfaceDatasetConfig(BaseModel):
    """Configuration for loading and preprocessing a Hugging Face dataset.

    NOTE(review): this copy lives in the shared config module and carries a
    project-specific default repo; the PR moves it next to the dataset-loading
    code and drops the default (see huggingface_datasets.py).
    """

    # Hugging Face Hub repository id ("owner/name"); defaults to a dummy dataset.
    hf_repo: str = "finegrain/unsplash-dummy"
    # Git revision (branch, tag, or commit) of the dataset repo to load.
    revision: str = "main"
    # Which dataset split to use for training.
    split: str = "train"
    # Whether to apply random horizontal flipping as augmentation.
    horizontal_flip: bool = False
    # Whether to take random crops (as opposed to deterministic ones).
    random_crop: bool = True
    # Enables Hugging Face dataset verification checks (BASIC_CHECKS vs NO_CHECKS).
    use_verification: bool = False
    # Images are resized so their short side falls within [min_size, max_size].
    # presumably enforced by the training pipeline's preprocessing — confirm there.
    resize_image_min_size: int = 512
    resize_image_max_size: int = 576


class CheckpointingConfig(BaseModel):
save_folder: Path | None = None
save_interval: TimeValue = {"number": 1, "unit": TimeUnit.EPOCH}
Expand All @@ -237,7 +226,6 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig
scheduler: SchedulerConfig
dropout: DropoutConfig
dataset: HuggingfaceDatasetConfig
checkpointing: CheckpointingConfig

@classmethod
Expand Down
12 changes: 12 additions & 0 deletions src/refiners/training_utils/huggingface_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Generic, Protocol, TypeVar, cast

from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
from pydantic import BaseModel # type: ignore

__all__ = ["load_hf_dataset", "HuggingfaceDataset"]

Expand All @@ -22,3 +23,14 @@ def load_hf_dataset(
verification_mode = VerificationMode.BASIC_CHECKS if use_verification else VerificationMode.NO_CHECKS
dataset = _load_dataset(path=path, revision=revision, split=split, verification_mode=verification_mode)
return cast(HuggingfaceDataset[Any], dataset)


class HuggingfaceDatasetConfig(BaseModel):
    """Configuration for loading and preprocessing a Hugging Face dataset.

    Unlike the previous shared-config version, `hf_repo` has no default:
    each training config must name its dataset repository explicitly.
    """

    # Hugging Face Hub repository id ("owner/name"); required, no default.
    hf_repo: str
    # Git revision (branch, tag, or commit) of the dataset repo to load.
    revision: str = "main"
    # Which dataset split to use for training.
    split: str = "train"
    # Whether to apply random horizontal flipping as augmentation.
    horizontal_flip: bool = False
    # Whether to take random crops (as opposed to deterministic ones).
    random_crop: bool = True
    # Enables Hugging Face dataset verification checks
    # (maps to VerificationMode.BASIC_CHECKS vs NO_CHECKS in load_hf_dataset).
    use_verification: bool = False
    # Images are resized so their short side falls within [min_size, max_size].
    # presumably enforced by the training pipeline's preprocessing — confirm there.
    resize_image_min_size: int = 512
    resize_image_max_size: int = 576
3 changes: 2 additions & 1 deletion src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, load_hf_dataset
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset
from refiners.training_utils.trainer import Trainer
from refiners.training_utils.wandb import WandbLoggable

Expand All @@ -44,6 +44,7 @@ class TestDiffusionConfig(BaseModel):


class FinetuneLatentDiffusionConfig(BaseConfig):
dataset: HuggingfaceDatasetConfig
latent_diffusion: LatentDiffusionConfig
test_diffusion: TestDiffusionConfig

Expand Down