forked from ostris/ai-toolkit
Added extensions and an example extension that merges models
1 parent b865ac8 · commit 7e4e660
Showing 14 changed files with 366 additions and 24 deletions.
@@ -0,0 +1,129 @@
import torch
import gc
from collections import OrderedDict
from typing import TYPE_CHECKING
from jobs.process import BaseExtensionProcess
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import get_torch_dtype
from tqdm import tqdm

# Type check imports. Prevents circular imports
if TYPE_CHECKING:
    from jobs import ExtensionJob


# extend standard config classes to add weight
class ModelInputConfig(ModelConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', 1.0)
        # overwrite default dtype unless user specifies otherwise
        # float32 will give better precision for the merging functions
        self.dtype: str = kwargs.get('dtype', 'float32')


def flush():
    torch.cuda.empty_cache()
    gc.collect()


# this is our main process class
class ExampleMergeModels(BaseExtensionProcess):
    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # This is the setup process; do not do process-intensive stuff here, just variable setup and
        # requirement checks. It is called before the run() function.
        # No loading models or anything like that, it is just for setting up the process.
        # All of your process-intensive work should be done in the run() function.
        # config will have everything from the process item in the config file.

        # Convenience methods exist on BaseProcess to get config values.
        # If required is set to True and the value is not found, it will throw an error.
        # You can also pass a default value to get_conf() for when the key is not in the config file,
        # as well as a type to cast the value to.
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # build the list of models to merge
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        # Build a list of ModelInputConfig objects. I find it is a good idea to make a class for each config
        # section; this way you can add methods to it and it is easier to read and code. There are a lot of
        # built-in config classes located in toolkit.config_modules as well.
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
        # Setup is complete. Don't load anything else here, just set up variables.

    # this is the entire run process; be sure to call super().run() first
    def run(self):
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        # let's adjust our weights first to normalize them so the total is 1.0
        total_weight = sum([model.weight for model in self.models_to_merge])
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust

        output_model: StableDiffusion = None
        # let's do the merge; it is a good idea to use tqdm to show progress
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            # set up the model with our helper class
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                dtype="float32"
            )
            # load the model
            sd_model.load_model()

            # Scale the text encoder weights by this model's merge weight.
            # Note: state_dict() returns tensors that share storage with the model's
            # parameters, so these in-place ops modify the loaded weights directly.
            if isinstance(sd_model.text_encoder, list):
                # sdxl model
                for text_encoder in sd_model.text_encoder:
                    for key, value in text_encoder.state_dict().items():
                        value *= model_config.weight
            else:
                # normal model
                for key, value in sd_model.text_encoder.state_dict().items():
                    value *= model_config.weight
            # scale the weights of the unet
            for key, value in sd_model.unet.state_dict().items():
                value *= model_config.weight

            if output_model is None:
                # use this one as the base
                output_model = sd_model
            else:
                # merge the models by adding the pre-scaled weights onto the base
                # text encoder
                if isinstance(output_model.text_encoder, list):
                    # sdxl model
                    for i, text_encoder in enumerate(output_model.text_encoder):
                        for key, value in text_encoder.state_dict().items():
                            value += sd_model.text_encoder[i].state_dict()[key]
                else:
                    # normal model
                    for key, value in output_model.text_encoder.state_dict().items():
                        value += sd_model.text_encoder.state_dict()[key]
                # unet
                for key, value in output_model.unet.state_dict().items():
                    value += sd_model.unet.state_dict()[key]

            # remove the model to free memory
            del sd_model
            flush()

        # merge loop is done, let's save the model
        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")
        # do cleanup here
        del output_model
        flush()
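To make the merge arithmetic above easier to follow in isolation, here is a minimal standalone sketch of the same weighted-sum idea on two tiny PyTorch modules. The modules and weights here are placeholders for illustration; the real process applies this to the text encoder and UNet of each loaded StableDiffusion model exactly as in run() above.

import torch

# two tiny stand-in "models"; the real code uses the toolkit's StableDiffusion wrapper
model_a = torch.nn.Linear(2, 2)
model_b = torch.nn.Linear(2, 2)
weights = [0.5, 0.5]  # already normalized, as run() does before merging

# scale the first model in place; state_dict() tensors share storage with the parameters
for key, value in model_a.state_dict().items():
    value *= weights[0]

# scale the second model, then accumulate it into the first
for key, value in model_b.state_dict().items():
    value *= weights[1]
for key, value in model_a.state_dict().items():
    value += model_b.state_dict()[key]

# model_a's parameters now hold the 50/50 weighted average of both models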
@@ -0,0 +1,25 @@
# This is an example extension for custom training. It is great for experimenting with new ideas.
from toolkit.extension import Extension


# We make a subclass of Extension
class ExampleMergeExtension(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "example_merge_extension"

    # name is the name of the extension for printing
    name = "Example Merge Extension"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .ExampleMergeModels import ExampleMergeModels
        return ExampleMergeModels


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    ExampleMergeExtension
]
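A new extension of your own would follow the same pattern: an Extension subclass exported through AI_TOOLKIT_EXTENSIONS from its own package under the extensions directory. The sketch below uses placeholder names (MyExtension, MyProcess) that are not part of this commit.

# hypothetical __init__.py for another extension package
from toolkit.extension import Extension


class MyExtension(Extension):
    # must be unique; this is what the `type` field in a config file refers to
    uid = "my_extension"
    name = "My Extension"

    @classmethod
    def get_process(cls):
        # placeholder module and class names
        from .MyProcess import MyProcess
        return MyProcess


AI_TOOLKIT_EXTENSIONS = [MyExtension]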
@@ -0,0 +1,48 @@
---
# Always include at least one example config file to show how to use your extension.
# Use plenty of comments so users know how to use it and what everything does.

# all extensions will use this job name
job: extension
config:
  name: 'my_awesome_merge'
  process:
    # Put your example processes here. This will be passed
    # to your extension process in the config argument.
    # The type MUST match your extension uid.
    - type: "example_merge_extension"
      # save path for the merged model
      save_path: "output/merge/[name].safetensors"
      # dtype to save the merged model in
      save_dtype: fp16
      # device to run it on
      device: cuda:0
      # input models can only be SD1.x and SD2.x models for this example (currently)
      models_to_merge:
        # Weights are relative; the total will be normalized.
        # For example, if you have 2 models with weight 1.0, they will
        # both be weighted 0.5. If you have 1 model with weight 1.0 and
        # another with weight 2.0, the first will be weighted 1/3 and the
        # second will be weighted 2/3.
        - name_or_path: "input/model1.safetensors"
          weight: 1.0
        - name_or_path: "input/model2.safetensors"
          weight: 1.0
        - name_or_path: "input/model3.safetensors"
          weight: 0.3
        - name_or_path: "input/model4.safetensors"
          weight: 1.0

# You can put any information you want here, and it will be saved in the model.
# The below is an example. I recommend doing trigger words at a minimum
# in the metadata. The software will include this plus some other information.
meta:
  name: "[name]"  # [name] gets replaced with the name above
  description: A short description of your model
  version: '0.1'
  creator:
    name: Your Name
    email: [email protected]
    website: https://yourwebsite.com
  any: All meta data above is arbitrary, it can be whatever you want.
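As a quick sanity check of the normalization comment above: the four relative weights in this example sum to 3.3, so the merge effectively uses about 0.30, 0.30, 0.09, and 0.30 of each model. A minimal sketch of the arithmetic run() performs:

# the relative weights from the example config above
weights = [1.0, 1.0, 0.3, 1.0]
total = sum(weights)                       # 3.3
normalized = [w / total for w in weights]  # ~[0.303, 0.303, 0.091, 0.303]
assert abs(sum(normalized) - 1.0) < 1e-9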
@@ -0,0 +1,21 @@
from collections import OrderedDict
from jobs import BaseJob
from toolkit.extension import get_all_extensions_process_dict


class ExtensionJob(BaseJob):

    def __init__(self, config: OrderedDict):
        super().__init__(config)
        self.device = self.get_conf('device', 'cpu')
        self.process_dict = get_all_extensions_process_dict()
        self.load_processes(self.process_dict)

    def run(self):
        super().run()

        print("")
        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

        for process in self.process:
            process.run()
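The link between a config file and an extension's process class is the process_dict built above: it maps each extension uid to the class returned by get_process(). load_processes() comes from BaseJob, which is not part of this commit, so the lookup below is only an assumption about how the `type` field is resolved, sketched for illustration.

# illustrative only; the real lookup is assumed to happen inside BaseJob.load_processes()
process_dict = get_all_extensions_process_dict()
# e.g. {"example_merge_extension": ExampleMergeModels}
process_class = process_dict["example_merge_extension"]  # the `type` value from the config
# process = process_class(process_id, job, process_config)  # then each process gets constructed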
@@ -0,0 +1,20 @@
from collections import OrderedDict
from typing import ForwardRef
from jobs.process.BaseProcess import BaseProcess


class BaseExtensionProcess(BaseProcess):
    process_id: int
    config: OrderedDict
    progress_bar: ForwardRef('tqdm') = None

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)

    def run(self):
        super().run()
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os | ||
import importlib | ||
import pkgutil | ||
from typing import List | ||
|
||
from toolkit.paths import TOOLKIT_ROOT | ||
|
||
|
||
class Extension(object): | ||
"""Base class for extensions. | ||
Extensions are registered with the ExtensionManager, which is | ||
responsible for calling the extension's load() and unload() | ||
methods at the appropriate times. | ||
""" | ||
|
||
name: str = None | ||
uid: str = None | ||
|
||
@classmethod | ||
def get_process(cls): | ||
# extend in subclass | ||
pass | ||
|
||
|
||
def get_all_extensions() -> List[Extension]: | ||
# Get the path of the "extensions" directory | ||
extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions") | ||
|
||
# This will hold the classes from all extension modules | ||
all_extension_classes: List[Extension] = [] | ||
|
||
# Iterate over all directories (i.e., packages) in the "extensions" directory | ||
for (_, name, _) in pkgutil.iter_modules([extensions_dir]): | ||
try: | ||
# Import the module | ||
module = importlib.import_module(f"extensions.{name}") | ||
# Get the value of the AI_TOOLKIT_EXTENSIONS variable | ||
extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None) | ||
# Check if the value is a list | ||
if isinstance(extensions, list): | ||
# Iterate over the list and add the classes to the main list | ||
all_extension_classes.extend(extensions) | ||
except ImportError as e: | ||
print(f"Failed to import the {name} module. Error: {str(e)}") | ||
|
||
return all_extension_classes | ||
|
||
|
||
def get_all_extensions_process_dict(): | ||
all_extensions = get_all_extensions() | ||
process_dict = {} | ||
for extension in all_extensions: | ||
process_dict[extension.uid] = extension.get_process() | ||
return process_dict |
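For discovery to pick an extension up, its package just has to live under the toolkit's extensions directory and expose an AI_TOOLKIT_EXTENSIONS list, as the example in this commit does. The layout below is an illustrative sketch of what get_all_extensions() expects; the inner package and file names are placeholders.

# Illustrative layout (inner names are placeholders):
#
#   extensions/
#       my_extension_pkg/
#           __init__.py           # defines AI_TOOLKIT_EXTENSIONS = [MyExtension]
#           MyProcess.py
#
# get_all_extensions() imports each package under extensions/ and collects the
# classes listed in its AI_TOOLKIT_EXTENSIONS; the uid -> process mapping then
# comes from get_all_extensions_process_dict():
from toolkit.extension import get_all_extensions_process_dict

print(get_all_extensions_process_dict())
# expected to include the example from this commit:
# {'example_merge_extension': ExampleMergeModels, ...}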