Skip to content

Commit

Permalink
Added extensions and an example extension that merges models
Browse files Browse the repository at this point in the history
  • Loading branch information
jaretburkett committed Aug 4, 2023
1 parent b865ac8 commit 7e4e660
Show file tree
Hide file tree
Showing 14 changed files with 366 additions and 24 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,6 @@ cython_debug/
!/config/examples
!/config/_PUT_YOUR_CONFIGS_HERE).txt
/output/*
!/output/.gitkeep
!/output/.gitkeep
/extensions/*
!/extensions/example
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ I will post an better tutorial soon.

---

## Extensions!!

You can now make and share custom extensions. That run within this framework and have all the inbuilt tools
available to them. I will probably use this as the primary development method going
forward so I dont keep adding and adding more and more features to this base repo. I will likely migrate a lot
of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
folder that shows how to make a model merger extension. All of the code is heavily documented which is hopefully
enough to get you started. To make an extension, just copy that example and replace all the things you need to.


### Model Merger - Example Extension
It is located in the `extensions` folder. It is a fully finctional model merger that can merge as many models together
as you want. It is a good example of how to make an extension, but is also a pretty useful feature as well since most
mergers can only do one model at a time and this one will take as many as you want to feed it. There is an
example config file in there, just copy that to your `config` folder and rename it to `whatever_you_want.yml`.
and use it like any other config file.

## WIP Tools


Expand Down Expand Up @@ -153,6 +170,12 @@ Just went in and out. It is much worse on smaller faces than shown here.

## Change Log

#### 2021-10-20
- Windows support bug fixes
- Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
check out the example in the `extensions` folder. Read more about that above.
- Model Merging, provided via the example extension.

#### 2021-08-03
Another big refactor to make SD more modular.

Expand Down
129 changes: 129 additions & 0 deletions extensions/example/ExampleMergeModels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm

# Type check imports. Prevents circular imports
if TYPE_CHECKING:
from jobs import ExtensionJob


# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.weight = kwargs.get('weight', 1.0)
# overwrite default dtype unless user specifies otherwise
# float 32 will give up better precision on the merging functions
self.dtype: str = kwargs.get('dtype', 'float32')


def flush():
torch.cuda.empty_cache()
gc.collect()


# this is our main class process
class ExampleMergeModels(BaseExtensionProcess):
def __init__(
self,
process_id: int,
job: 'ExtensionJob',
config: OrderedDict
):
super().__init__(process_id, job, config)
# this is the setup process, do not do process intensive stuff here, just variable setup and
# checking requirements. This is called before the run() function
# no loading models or anything like that, it is just for setting up the process
# all of your process intensive stuff should be done in the run() function
# config will have everything from the process item in the config file

# convince methods exist on BaseProcess to get config values
# if required is set to true and the value is not found it will throw an error
# you can pass a default value to get_conf() as well if it was not in the config file
# as well as a type to cast the value to
self.save_path = self.get_conf('save_path', required=True)
self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
self.device = self.get_conf('device', default='cpu', as_type=torch.device)

# build models to merge list
models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
# build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
# this way you can add methods to it and it is easier to read and code. There are a lot of
# inbuilt config classes located in toolkit.config_modules as well
self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
# setup is complete. Don't load anything else here, just setup variables and stuff

# this is the entire run process be sure to call super().run() first
def run(self):
# always call first
super().run()
print(f"Running process: {self.__class__.__name__}")

# let's adjust our weights first to normalize them so the total is 1.0
total_weight = sum([model.weight for model in self.models_to_merge])
weight_adjust = 1.0 / total_weight
for model in self.models_to_merge:
model.weight *= weight_adjust

output_model: StableDiffusion = None
# let's do the merge, it is a good idea to use tqdm to show progress
for model_config in tqdm(self.models_to_merge, desc="Merging models"):
# setup model class with our helper class
sd_model = StableDiffusion(
device=self.device,
model_config=model_config,
dtype="float32"
)
# load the model
sd_model.load_model()

# adjust the weight of the text encoder
if isinstance(sd_model.text_encoder, list):
# sdxl model
for text_encoder in sd_model.text_encoder:
for key, value in text_encoder.state_dict().items():
value *= model_config.weight
else:
# normal model
for key, value in sd_model.text_encoder.state_dict().items():
value *= model_config.weight
# adjust the weights of the unet
for key, value in sd_model.unet.state_dict().items():
value *= model_config.weight

if output_model is None:
# use this one as the base
output_model = sd_model
else:
# merge the models
# text encoder
if isinstance(output_model.text_encoder, list):
# sdxl model
for i, text_encoder in enumerate(output_model.text_encoder):
for key, value in text_encoder.state_dict().items():
value += sd_model.text_encoder[i].state_dict()[key]
else:
# normal model
for key, value in output_model.text_encoder.state_dict().items():
value += sd_model.text_encoder.state_dict()[key]
# unet
for key, value in output_model.unet.state_dict().items():
value += sd_model.unet.state_dict()[key]

# remove the model to free memory
del sd_model
flush()

# merge loop is done, let's save the model
print(f"Saving merged model to {self.save_path}")
output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
print(f"Saved merged model to {self.save_path}")
# do cleanup here
del output_model
flush()
25 changes: 25 additions & 0 deletions extensions/example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ExampleMergeExtension(Extension):
# uid must be unique, it is how the extension is identified
uid = "example_merge_extension"

# name is the name of the extension for printing
name = "Example Merge Extension"

# This is where your process class is loaded
# keep your imports in here so they don't slow down the rest of the program
@classmethod
def get_process(cls):
# import your process class here so it is only loaded when needed and return it
from .ExampleMergeModels import ExampleMergeModels
return ExampleMergeModels


AI_TOOLKIT_EXTENSIONS = [
# you can put a list of extensions here
ExampleMergeExtension
]
48 changes: 48 additions & 0 deletions extensions/example/config/config.example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# use plenty of comments so users know how to use it and what everything does

# all extensions will use this job name
job: extension
config:
name: 'my_awesome_merge'
process:
# Put your example processes here. This will be passed
# to your extension process in the config argument.
# the type MUST match your extension uid
- type: "example_merge_extension"
# save path for the merged model
save_path: "output/merge/[name].safetensors"
# save type
dtype: fp16
# device to run it on
device: cuda:0
# input models can only be SD1.x and SD2.x models for this example (currently)
models_to_merge:
# weights are relative, total weights will be normalized
# for example. If you have 2 models with weight 1.0, they will
# both be weighted 0.5. If you have 1 model with weight 1.0 and
# another with weight 2.0, the first will be weighted 1/3 and the
# second will be weighted 2/3
- name_or_path: "input/model1.safetensors"
weight: 1.0
- name_or_path: "input/model2.safetensors"
weight: 1.0
- name_or_path: "input/model3.safetensors"
weight: 0.3
- name_or_path: "input/model4.safetensors"
weight: 1.0


# you can put any information you want here, and it will be saved in the model
# the below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information
meta:
name: "[name]" # [name] gets replaced with the name above
description: A short description of your model
version: '0.1'
creator:
name: Your Name
email: [email protected]
website: https://yourwebsite.com
any: All meta data above is arbitrary, it can be whatever you want.
2 changes: 1 addition & 1 deletion info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.2"
v["version"] = "0.0.3"

software_meta = v
6 changes: 5 additions & 1 deletion jobs/BaseJob.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ def load_processes(self, process_dict: dict):

# check if dict key is process type
if process['type'] in process_dict:
ProcessClass = getattr(module, process_dict[process['type']])
if isinstance(process_dict[process['type']], str):
ProcessClass = getattr(module, process_dict[process['type']])
else:
# it is the class
ProcessClass = process_dict[process['type']]
self.process.append(ProcessClass(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
Expand Down
21 changes: 21 additions & 0 deletions jobs/ExtensionJob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from collections import OrderedDict
from jobs import BaseJob
from toolkit.extension import get_all_extensions_process_dict


class ExtensionJob(BaseJob):

def __init__(self, config: OrderedDict):
super().__init__(config)
self.device = self.get_conf('device', 'cpu')
self.process_dict = get_all_extensions_process_dict()
self.load_processes(self.process_dict)

def run(self):
super().run()

print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

for process in self.process:
process.run()
1 change: 1 addition & 0 deletions jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .MergeJob import MergeJob
from .ModJob import ModJob
from .GenerateJob import GenerateJob
from .ExtensionJob import ExtensionJob
20 changes: 20 additions & 0 deletions jobs/process/BaseExtensionProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from collections import OrderedDict
from typing import ForwardRef
from jobs.process.BaseProcess import BaseProcess


class BaseExtensionProcess(BaseProcess):
process_id: int
config: OrderedDict
progress_bar: ForwardRef('tqdm') = None

def __init__(
self,
process_id: int,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)

def run(self):
super().run()
1 change: 1 addition & 0 deletions jobs/process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess
from .BaseExtensionProcess import BaseExtensionProcess
56 changes: 56 additions & 0 deletions toolkit/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import importlib
import pkgutil
from typing import List

from toolkit.paths import TOOLKIT_ROOT


class Extension(object):
"""Base class for extensions.
Extensions are registered with the ExtensionManager, which is
responsible for calling the extension's load() and unload()
methods at the appropriate times.
"""

name: str = None
uid: str = None

@classmethod
def get_process(cls):
# extend in subclass
pass


def get_all_extensions() -> List[Extension]:
# Get the path of the "extensions" directory
extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")

# This will hold the classes from all extension modules
all_extension_classes: List[Extension] = []

# Iterate over all directories (i.e., packages) in the "extensions" directory
for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
try:
# Import the module
module = importlib.import_module(f"extensions.{name}")
# Get the value of the AI_TOOLKIT_EXTENSIONS variable
extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
# Check if the value is a list
if isinstance(extensions, list):
# Iterate over the list and add the classes to the main list
all_extension_classes.extend(extensions)
except ImportError as e:
print(f"Failed to import the {name} module. Error: {str(e)}")

return all_extension_classes


def get_all_extensions_process_dict():
all_extensions = get_all_extensions()
process_dict = {}
for extension in all_extensions:
process_dict[extension.uid] = extension.get_process()
return process_dict
3 changes: 3 additions & 0 deletions toolkit/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def get_job(config_path, name=None):
if job == 'generate':
from jobs import GenerateJob
return GenerateJob(config)
if job == 'extension':
from jobs import ExtensionJob
return ExtensionJob(config)

# elif job == 'train':
# from jobs import TrainJob
Expand Down
Loading

0 comments on commit 7e4e660

Please sign in to comment.