From 7e4e660663067e71395e517ca3dc141f645a56f7 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 4 Aug 2023 09:37:24 -0600
Subject: [PATCH] Added extensions and an example extension that merges models

---
 .gitignore                                    |   4 +-
 README.md                                     |  23 ++++
 extensions/example/ExampleMergeModels.py      | 129 ++++++++++++++++++
 extensions/example/__init__.py                |  25 ++++
 extensions/example/config/config.example.yaml |  48 +++++++
 info.py                                       |   2 +-
 jobs/BaseJob.py                               |   6 +-
 jobs/ExtensionJob.py                          |  21 +++
 jobs/__init__.py                              |   1 +
 jobs/process/BaseExtensionProcess.py          |  20 +++
 jobs/process/__init__.py                      |   1 +
 toolkit/extension.py                          |  56 ++++++++
 toolkit/job.py                                |   3 +
 toolkit/stable_diffusion_model.py             |  51 ++++---
 14 files changed, 366 insertions(+), 24 deletions(-)
 create mode 100644 extensions/example/ExampleMergeModels.py
 create mode 100644 extensions/example/__init__.py
 create mode 100644 extensions/example/config/config.example.yaml
 create mode 100644 jobs/ExtensionJob.py
 create mode 100644 jobs/process/BaseExtensionProcess.py
 create mode 100644 toolkit/extension.py

diff --git a/.gitignore b/.gitignore
index f1480418..5a3ba0ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -170,4 +170,6 @@ cython_debug/
 !/config/examples
 !/config/_PUT_YOUR_CONFIGS_HERE).txt
 /output/*
-!/output/.gitkeep
\ No newline at end of file
+!/output/.gitkeep
+/extensions/*
+!/extensions/example
\ No newline at end of file
diff --git a/README.md b/README.md
index 50f7ee8f..0885e733 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,23 @@ I will post a better tutorial soon.
 
 ---
 
+## Extensions!!
+
+You can now make and share custom extensions that run within this framework and have all the inbuilt tools
+available to them. I will probably use this as the primary development method going
+forward so I don't keep adding more and more features to this base repo. I will likely migrate a lot
+of the existing functionality as well to make everything modular. There is an example extension in the `extensions`
+folder that shows how to make a model merger extension. All of the code is heavily documented, which is hopefully
+enough to get you started. To make an extension, just copy that example and replace the parts you need to.
+
+
+### Model Merger - Example Extension
+It is located in the `extensions` folder. It is a fully functional model merger that can merge as many models
+together as you want. It is a good example of how to make an extension, and it is a pretty useful feature too,
+since most mergers can only do two models at a time while this one will take as many as you want to feed it.
+There is an example config file in there; just copy it to your `config` folder, rename it to
+`whatever_you_want.yml`, and use it like any other config file.
+
 
 ## WIP Tools
 
@@ -153,6 +170,12 @@ Just went in and out. It is much worse on smaller faces than shown here.
 
 ## Change Log
 
+#### 2023-08-04
+ - Windows support bug fixes
+ - Extensions! Added functionality to make and share custom extensions for training, merging, whatever.
+Check out the example in the `extensions` folder, and read more about that above.
+ - Model Merging, provided via the example extension.
+
 #### 2023-08-03
 Another big refactor to make SD more modular.
 
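For a quick orientation before the full example below, this is the minimal shape of an extension package. It is a sketch only, with hypothetical `MyExtension`/`MyProcess` names; it mirrors the hooks the loader looks for in the example extension that follows.

```python
# extensions/my_extension/__init__.py - hypothetical minimal skeleton
from toolkit.extension import Extension


class MyExtension(Extension):
    # uid is how configs reference the extension via their `type` field
    uid = "my_extension"
    name = "My Extension"

    @classmethod
    def get_process(cls):
        # deferred import so the process module only loads when the job runs
        from .MyProcess import MyProcess  # hypothetical process class
        return MyProcess


# the loader discovers extensions through this module-level list
AI_TOOLKIT_EXTENSIONS = [MyExtension]
```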
diff --git a/extensions/example/ExampleMergeModels.py b/extensions/example/ExampleMergeModels.py
new file mode 100644
index 00000000..162d514c
--- /dev/null
+++ b/extensions/example/ExampleMergeModels.py
@@ -0,0 +1,129 @@
+import torch
+import gc
+from collections import OrderedDict
+from typing import TYPE_CHECKING
+from jobs.process import BaseExtensionProcess
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.train_tools import get_torch_dtype
+from tqdm import tqdm
+
+# type-check-only imports; prevents circular imports
+if TYPE_CHECKING:
+    from jobs import ExtensionJob
+
+
+# extend standard config classes to add weight
+class ModelInputConfig(ModelConfig):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.weight = kwargs.get('weight', 1.0)
+        # overwrite default dtype unless the user specifies otherwise;
+        # float32 will give better precision in the merging functions
+        self.dtype: str = kwargs.get('dtype', 'float32')
+
+
+def flush():
+    torch.cuda.empty_cache()
+    gc.collect()
+
+
+# this is our main process class
+class ExampleMergeModels(BaseExtensionProcess):
+    def __init__(
+            self,
+            process_id: int,
+            job: 'ExtensionJob',
+            config: OrderedDict
+    ):
+        super().__init__(process_id, job, config)
+        # This is the setup step. Do not do anything process-intensive here, just variable setup
+        # and requirement checking. It is called before the run() function.
+        # No loading models or anything like that; it is just for setting up the process.
+        # All of your process-intensive work should be done in the run() function.
+        # config will have everything from the process item in the config file.
+
+        # Convenience methods exist on BaseProcess to get config values.
+        # If required is set to True and the value is not found, it will throw an error.
+        # You can pass a default value to get_conf() for when it is not in the config file,
+        # as well as a type to cast the value to.
+        self.save_path = self.get_conf('save_path', required=True)
+        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
+        self.device = self.get_conf('device', default='cpu', as_type=torch.device)
+
+        # build the models-to-merge list
+        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
+        # Build a list of ModelInputConfig objects. I find it is a good idea to make a class for each config;
+        # this way you can add methods to it, and it is easier to read and code. There are a lot of
+        # inbuilt config classes located in toolkit.config_modules as well.
+        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
+        # setup is complete; don't load anything else here, just set up variables and such
+
+    # this is the entire run process; be sure to call super().run() first
+    def run(self):
+        # always call first
+        super().run()
+        print(f"Running process: {self.__class__.__name__}")
+
+        # let's adjust our weights first to normalize them so the total is 1.0
+        total_weight = sum([model.weight for model in self.models_to_merge])
+        weight_adjust = 1.0 / total_weight
+        for model in self.models_to_merge:
+            model.weight *= weight_adjust
+
+        output_model: StableDiffusion = None
+        # let's do the merge; it is a good idea to use tqdm to show progress
+        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
+            # set up the model with our helper class
+            sd_model = StableDiffusion(
+                device=self.device,
+                model_config=model_config,
+                dtype="float32"
+            )
+            # load the model
+            sd_model.load_model()
+
+            # scale the text encoder weights in place (state_dict() returns references to the live tensors)
+            if isinstance(sd_model.text_encoder, list):
+                # sdxl model
+                for text_encoder in sd_model.text_encoder:
+                    for key, value in text_encoder.state_dict().items():
+                        value *= model_config.weight
+            else:
+                # normal model
+                for key, value in sd_model.text_encoder.state_dict().items():
+                    value *= model_config.weight
+            # scale the unet weights in place
+            for key, value in sd_model.unet.state_dict().items():
+                value *= model_config.weight
+
+            if output_model is None:
+                # use this one as the base
+                output_model = sd_model
+            else:
+                # add this model's pre-scaled weights onto the base
+                # text encoder
+                if isinstance(output_model.text_encoder, list):
+                    # sdxl model
+                    for i, text_encoder in enumerate(output_model.text_encoder):
+                        for key, value in text_encoder.state_dict().items():
+                            value += sd_model.text_encoder[i].state_dict()[key]
+                else:
+                    # normal model
+                    for key, value in output_model.text_encoder.state_dict().items():
+                        value += sd_model.text_encoder.state_dict()[key]
+                # unet
+                for key, value in output_model.unet.state_dict().items():
+                    value += sd_model.unet.state_dict()[key]
+
+            # remove the reference to free memory
+            del sd_model
+            flush()
+
+        # the merge loop is done; let's save the model
+        print(f"Saving merged model to {self.save_path}")
+        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
+        print(f"Saved merged model to {self.save_path}")
+        # do cleanup here
+        del output_model
+        flush()
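At its core, the process above is just a weighted sum of state dicts. Here is a standalone sketch of that arithmetic on plain tensor dicts (no `StableDiffusion` wrapper, and out-of-place accumulation instead of the in-place scaling used above):

```python
import torch
from typing import Dict, List, Tuple

StateDict = Dict[str, torch.Tensor]


def weighted_merge(models: List[Tuple[StateDict, float]]) -> StateDict:
    # normalize the relative weights so they sum to 1.0
    total = sum(weight for _, weight in models)
    merged: StateDict = {}
    for state_dict, weight in models:
        for key, value in state_dict.items():
            scaled = value.float() * (weight / total)
            merged[key] = scaled if key not in merged else merged[key] + scaled
    return merged
```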
diff --git a/extensions/example/__init__.py b/extensions/example/__init__.py
new file mode 100644
index 00000000..34f348f1
--- /dev/null
+++ b/extensions/example/__init__.py
@@ -0,0 +1,25 @@
+# This is an example extension for custom training. It is great for experimenting with new ideas.
+from toolkit.extension import Extension
+
+
+# We make a subclass of Extension
+class ExampleMergeExtension(Extension):
+    # uid must be unique; it is how the extension is identified
+    uid = "example_merge_extension"
+
+    # name is the name of the extension for printing
+    name = "Example Merge Extension"
+
+    # This is where your process class is loaded.
+    # Keep your imports in here so they don't slow down the rest of the program.
+    @classmethod
+    def get_process(cls):
+        # import your process class here so it is only loaded when needed, and return it
+        from .ExampleMergeModels import ExampleMergeModels
+        return ExampleMergeModels
+
+
+AI_TOOLKIT_EXTENSIONS = [
+    # you can put a list of extensions here
+    ExampleMergeExtension
+]
diff --git a/extensions/example/config/config.example.yaml b/extensions/example/config/config.example.yaml
new file mode 100644
index 00000000..abed03fd
--- /dev/null
+++ b/extensions/example/config/config.example.yaml
@@ -0,0 +1,48 @@
+---
+# Always include at least one example config file to show how to use your extension.
+# Use plenty of comments so users know how to use it and what everything does.
+
+# all extensions will use this job name
+job: extension
+config:
+  name: 'my_awesome_merge'
+  process:
+    # Put your example processes here. This will be passed
+    # to your extension process in the config argument.
+    # the type MUST match your extension uid
+    - type: "example_merge_extension"
+      # save path for the merged model
+      save_path: "output/merge/[name].safetensors"
+      # dtype to save the merged model in
+      save_dtype: fp16
+      # device to run it on
+      device: cuda:0
+      # input models can only be SD1.x and SD2.x models for this example (currently)
+      models_to_merge:
+        # Weights are relative; total weights will be normalized.
+        # For example, if you have 2 models with weight 1.0, they will
+        # both be weighted 0.5. If you have 1 model with weight 1.0 and
+        # another with weight 2.0, the first will be weighted 1/3 and the
+        # second will be weighted 2/3.
+        - name_or_path: "input/model1.safetensors"
+          weight: 1.0
+        - name_or_path: "input/model2.safetensors"
+          weight: 1.0
+        - name_or_path: "input/model3.safetensors"
+          weight: 0.3
+        - name_or_path: "input/model4.safetensors"
+          weight: 1.0
+
+
+# You can put any information you want here, and it will be saved in the model.
+# The below is an example. I recommend including trigger words at a minimum
+# in the metadata. The software will include this plus some other information.
+meta:
+  name: "[name]"  # [name] gets replaced with the name above
+  description: A short description of your model
+  version: '0.1'
+  creator:
+    name: Your Name
+    email: your@email.com
+    website: https://yourwebsite.com
+  any: All meta data above is arbitrary, it can be whatever you want.
\ No newline at end of file
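With the example weights above (1.0, 1.0, 0.3, 1.0), normalization divides each by the 3.3 total, so model3 contributes roughly 9% of the merge while the others contribute roughly 30% each. A quick check:

```python
weights = [1.0, 1.0, 0.3, 1.0]
total = sum(weights)  # 3.3
print([round(w / total, 3) for w in weights])
# [0.303, 0.303, 0.091, 0.303]
```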
diff --git a/info.py b/info.py
index 85113e00..b40d9344 100644
--- a/info.py
+++ b/info.py
@@ -3,6 +3,6 @@
 v = OrderedDict()
 v["name"] = "ai-toolkit"
 v["repo"] = "https://github.com/ostris/ai-toolkit"
-v["version"] = "0.0.2"
+v["version"] = "0.0.3"
 
 software_meta = v
diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py
index f29a29d1..6027e71f 100644
--- a/jobs/BaseJob.py
+++ b/jobs/BaseJob.py
@@ -60,7 +60,11 @@ def load_processes(self, process_dict: dict):
 
             # check if dict key is process type
             if process['type'] in process_dict:
-                ProcessClass = getattr(module, process_dict[process['type']])
+                if isinstance(process_dict[process['type']], str):
+                    ProcessClass = getattr(module, process_dict[process['type']])
+                else:
+                    # it is the class
+                    ProcessClass = process_dict[process['type']]
                 self.process.append(ProcessClass(i, self, process))
             else:
                 raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
diff --git a/jobs/ExtensionJob.py b/jobs/ExtensionJob.py
new file mode 100644
index 00000000..e1ddc965
--- /dev/null
+++ b/jobs/ExtensionJob.py
@@ -0,0 +1,21 @@
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.extension import get_all_extensions_process_dict
+
+
+class ExtensionJob(BaseJob):
+
+    def __init__(self, config: OrderedDict):
+        super().__init__(config)
+        self.device = self.get_conf('device', 'cpu')
+        self.process_dict = get_all_extensions_process_dict()
+        self.load_processes(self.process_dict)
+
+    def run(self):
+        super().run()
+
+        print("")
+        print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+        for process in self.process:
+            process.run()
diff --git a/jobs/__init__.py b/jobs/__init__.py
index 9f232a52..7da6c22b 100644
--- a/jobs/__init__.py
+++ b/jobs/__init__.py
@@ -4,3 +4,4 @@
 from .MergeJob import MergeJob
 from .ModJob import ModJob
 from .GenerateJob import GenerateJob
+from .ExtensionJob import ExtensionJob
diff --git a/jobs/process/BaseExtensionProcess.py b/jobs/process/BaseExtensionProcess.py
new file mode 100644
index 00000000..d6185633
--- /dev/null
+++ b/jobs/process/BaseExtensionProcess.py
@@ -0,0 +1,20 @@
+from collections import OrderedDict
+from typing import ForwardRef
+from jobs.process.BaseProcess import BaseProcess
+
+
+class BaseExtensionProcess(BaseProcess):
+    process_id: int
+    config: OrderedDict
+    progress_bar: ForwardRef('tqdm') = None
+
+    def __init__(
+            self,
+            process_id: int,
+            job,
+            config: OrderedDict
+    ):
+        super().__init__(process_id, job, config)
+
+    def run(self):
+        super().run()
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
index a227a7a7..4c7c7660 100644
--- a/jobs/process/__init__.py
+++ b/jobs/process/__init__.py
@@ -11,3 +11,4 @@
 from .TrainSDRescaleProcess import TrainSDRescaleProcess
 from .ModRescaleLoraProcess import ModRescaleLoraProcess
 from .GenerateProcess import GenerateProcess
+from .BaseExtensionProcess import BaseExtensionProcess
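The `BaseJob.load_processes()` change above means a `process_dict` value can now be either a class-name string (resolved with `getattr` against the job's module, as before) or a class object (what extensions provide). Roughly, the two shapes look like this; the string entry is a hypothetical illustration:

```python
# built-in jobs: type -> class-name string, resolved via getattr(module, name)
process_dict = {"mod_rescale_lora": "ModRescaleLoraProcess"}  # hypothetical key

# extension jobs: type (the extension uid) -> the class itself
from extensions.example import ExampleMergeExtension
process_dict = {ExampleMergeExtension.uid: ExampleMergeExtension.get_process()}
```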
diff --git a/toolkit/extension.py b/toolkit/extension.py
new file mode 100644
index 00000000..cb47329d
--- /dev/null
+++ b/toolkit/extension.py
@@ -0,0 +1,56 @@
+import os
+import importlib
+import pkgutil
+from typing import List
+
+from toolkit.paths import TOOLKIT_ROOT
+
+
+class Extension(object):
+    """Base class for extensions.
+
+    Extensions are discovered from packages in the "extensions" directory
+    that export an AI_TOOLKIT_EXTENSIONS list; each one declares a unique
+    uid and returns its process class from get_process().
+
+    """
+
+    name: str = None
+    uid: str = None
+
+    @classmethod
+    def get_process(cls):
+        # extend in subclass
+        pass
+
+
+def get_all_extensions() -> List[Extension]:
+    # Get the path of the "extensions" directory
+    extensions_dir = os.path.join(TOOLKIT_ROOT, "extensions")
+
+    # This will hold the classes from all extension modules
+    all_extension_classes: List[Extension] = []
+
+    # Iterate over all directories (i.e., packages) in the "extensions" directory
+    for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
+        try:
+            # Import the module
+            module = importlib.import_module(f"extensions.{name}")
+            # Get the value of the AI_TOOLKIT_EXTENSIONS variable
+            extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
+            # Check if the value is a list
+            if isinstance(extensions, list):
+                # Iterate over the list and add the classes to the main list
+                all_extension_classes.extend(extensions)
+        except ImportError as e:
+            print(f"Failed to import the {name} module. Error: {str(e)}")
+
+    return all_extension_classes
+
+
+def get_all_extensions_process_dict():
+    all_extensions = get_all_extensions()
+    process_dict = {}
+    for extension in all_extensions:
+        process_dict[extension.uid] = extension.get_process()
+    return process_dict
diff --git a/toolkit/job.py b/toolkit/job.py
index 60752740..a828e3c8 100644
--- a/toolkit/job.py
+++ b/toolkit/job.py
@@ -19,6 +19,9 @@ def get_job(config_path, name=None):
     if job == 'generate':
         from jobs import GenerateJob
         return GenerateJob(config)
+    if job == 'extension':
+        from jobs import ExtensionJob
+        return ExtensionJob(config)
 
     # elif job == 'train':
     #     from jobs import TrainJob
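With the new `extension` branch in `get_job()`, an extension config can also be kicked off programmatically. A minimal sketch, assuming a config file copied from the example above:

```python
from toolkit.job import get_job

# reads the YAML, sees `job: extension`, and returns an ExtensionJob,
# which discovers installed extensions and builds the matching processes
job = get_job('config/whatever_you_want.yml')
job.run()
```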
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5fb253d6..2e924d19 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -8,7 +8,10 @@
 from safetensors.torch import save_file
 from tqdm import tqdm
 
+from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
+    convert_vae_state_dict
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
+from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
 
@@ -161,6 +164,7 @@ def load_model(self):
             scheduler_type='dpm',
             device=self.device_torch,
             load_safety_checker=False,
+            requires_safety_checker=False,
         ).to(self.device_torch)
         pipe.register_to_config(requires_safety_checker=False)
         text_encoder = pipe.text_encoder
@@ -468,17 +472,16 @@ def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
         )
 
     def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
-        # todo see what logit scale is
-        if self.is_xl:
-
-            state_dict = {}
+        state_dict = {}
 
-            def update_sd(prefix, sd):
-                for k, v in sd.items():
-                    key = prefix + k
-                    v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
-                    state_dict[key] = v
+        def update_sd(prefix, sd):
+            for k, v in sd.items():
+                key = prefix + k
+                v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
+                state_dict[key] = v
 
+        # todo see what logit scale is
+        if self.is_xl:
             # Convert the UNet model
             update_sd("model.diffusion_model.", self.unet.state_dict())
 
@@ -488,19 +491,25 @@ def update_sd(prefix, sd):
             text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
             update_sd("conditioner.embedders.1.model.", text_enc2_dict)
 
+        else:
+            # Convert the UNet model
+            unet_state_dict = convert_unet_state_dict_to_sd(self.is_v2, self.unet.state_dict())
+            update_sd("model.diffusion_model.", unet_state_dict)
+
+            # Convert the text encoder model
+            if self.is_v2:
+                make_dummy = True
+                text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(self.text_encoder.state_dict(), make_dummy)
+                update_sd("cond_stage_model.model.", text_enc_dict)
+            else:
+                text_enc_dict = self.text_encoder.state_dict()
+                update_sd("cond_stage_model.transformer.", text_enc_dict)
+        # Convert the VAE
+        if self.vae is not None:
             vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
             update_sd("first_stage_model.", vae_dict)
 
-            # Put together new checkpoint
-            key_count = len(state_dict.keys())
-            new_ckpt = {"state_dict": state_dict}
-
-            if model_util.is_safetensors(output_file):
-                save_file(state_dict, output_file)
-            else:
-                torch.save(new_ckpt, output_file, meta)
-
-            return key_count
-        else:
-            raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
+        # prepare metadata
+        meta = get_meta_for_safetensors(meta)
+        save_file(state_dict, output_file, metadata=meta)
 
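One note on the new save path: `safetensors` only accepts flat `Dict[str, str]` metadata, so `get_meta_for_safetensors()` presumably stringifies the nested `meta` block before `save_file()` is called. A sketch of that flattening, under that assumption:

```python
import json
from collections import OrderedDict


def flatten_meta_for_safetensors(meta: OrderedDict) -> dict:
    # nested values (e.g. the `creator` block) become JSON strings,
    # since safetensors metadata values must all be plain strings
    return {str(k): v if isinstance(v, str) else json.dumps(v)
            for k, v in meta.items()}
```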