forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reworked so everything is in classes for easy expansion. Single entry…
… point for all config files now.
- Loading branch information
1 parent
27df03a
commit 37354b0
Showing
16 changed files
with
424 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,32 @@ | ||
{ | ||
"job": "extract", | ||
"config": { | ||
"name": "name_of_your_model", | ||
"base_model": "/path/to/base/model", | ||
"extract_model": "/path/to/model/to/extract", | ||
"output_folder": "/path/to/output/folder", | ||
"is_v2": false, | ||
"device": "cpu", | ||
"use_sparse_bias": false, | ||
"sparsity": 0.98, | ||
"disable_cp": false, | ||
"process": [ | ||
{ | ||
"filename":"[name]_64_32.safetensors", | ||
"type": "locon", | ||
"mode": "fixed", | ||
"linear_dim": 64, | ||
"conv_dim": 32 | ||
"linear": 64, | ||
"conv": 32 | ||
}, | ||
{ | ||
"output_path": "/absolute/path/for/this/output.safetensors", | ||
"type": "locon", | ||
"mode": "ratio", | ||
"linear_ratio": 0.2, | ||
"conv_ratio": 0.2 | ||
"linear": 0.2, | ||
"conv": 0.2 | ||
}, | ||
{ | ||
"type": "locon", | ||
"mode": "quantile", | ||
"linear_quantile": 0.5, | ||
"conv_quantile": 0.5 | ||
"linear": 0.5, | ||
"conv": 0.5 | ||
} | ||
] | ||
}, | ||
|
@@ -41,6 +44,7 @@ | |
"name": "Your Name", | ||
"email": "[email protected]", | ||
"website": "https://yourwebsite.com" | ||
} | ||
}, | ||
"any": "All meta data above is arbitrary, it can be whatever you want." | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from collections import OrderedDict | ||
|
||
v = OrderedDict() | ||
v["name"] = "ai-toolkit" | ||
v["repo"] = "https://github.com/ostris/ai-toolkit" | ||
v["version"] = "0.0.1" | ||
|
||
software_meta = v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from collections import OrderedDict | ||
|
||
|
||
class BaseJob: | ||
config: OrderedDict | ||
job: str | ||
name: str | ||
meta: OrderedDict | ||
|
||
def __init__(self, config: OrderedDict): | ||
if not config: | ||
raise ValueError('config is required') | ||
|
||
self.config = config['config'] | ||
self.job = config['job'] | ||
self.name = self.get_conf('name', required=True) | ||
if 'meta' in config: | ||
self.meta = config['meta'] | ||
else: | ||
self.meta = OrderedDict() | ||
|
||
def get_conf(self, key, default=None, required=False): | ||
if key in self.config: | ||
return self.config[key] | ||
elif required: | ||
raise ValueError(f'config file error. Missing "config.{key}" key') | ||
else: | ||
return default | ||
|
||
def run(self): | ||
print("") | ||
print(f"#############################################") | ||
print(f"# Running job: {self.name}") | ||
print(f"#############################################") | ||
print("") | ||
# implement in child class | ||
# be sure to call super().run() first | ||
pass | ||
|
||
def cleanup(self): | ||
# if you implement this in child clas, | ||
# be sure to call super().cleanup() LAST | ||
del self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint | ||
from .BaseJob import BaseJob | ||
from collections import OrderedDict | ||
from typing import List | ||
|
||
from jobs.process import BaseExtractProcess | ||
|
||
|
||
class ExtractJob(BaseJob): | ||
process: List[BaseExtractProcess] | ||
|
||
def __init__(self, config: OrderedDict): | ||
super().__init__(config) | ||
self.base_model_path = self.get_conf('base_model', required=True) | ||
self.base_model = None | ||
self.extract_model_path = self.get_conf('extract_model', required=True) | ||
self.extract_model = None | ||
self.output_folder = self.get_conf('output_folder', required=True) | ||
self.is_v2 = self.get_conf('is_v2', False) | ||
self.device = self.get_conf('device', 'cpu') | ||
|
||
if 'process' not in self.config: | ||
raise ValueError('config file is invalid. Missing "config.process" key') | ||
if len(self.config['process']) == 0: | ||
raise ValueError('config file is invalid. "config.process" must be a list of processes') | ||
|
||
# add the processes | ||
self.process = [] | ||
for i, process in enumerate(self.config['process']): | ||
if 'type' not in process: | ||
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key') | ||
if process['type'] == 'locon': | ||
from jobs.process import LoconExtractProcess | ||
self.process.append(LoconExtractProcess(i, self, process)) | ||
else: | ||
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}') | ||
|
||
def run(self): | ||
super().run() | ||
# load models | ||
print(f"Loading models for extraction") | ||
print(f" - Loading base model: {self.base_model_path}") | ||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) | ||
|
||
print(f" - Loading extract model: {self.extract_model_path}") | ||
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path) | ||
|
||
print("") | ||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") | ||
|
||
for process in self.process: | ||
process.run() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .BaseJob import BaseJob | ||
from .ExtractJob import ExtractJob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os | ||
from collections import OrderedDict | ||
|
||
from safetensors.torch import save_file | ||
|
||
from jobs import ExtractJob | ||
from jobs.process.BaseProcess import BaseProcess | ||
from toolkit.metadata import get_meta_for_safetensors | ||
|
||
|
||
class BaseExtractProcess(BaseProcess): | ||
job: ExtractJob | ||
process_id: int | ||
config: OrderedDict | ||
output_folder: str | ||
output_filename: str | ||
output_path: str | ||
|
||
def __init__( | ||
self, | ||
process_id: int, | ||
job: ExtractJob, | ||
config: OrderedDict | ||
): | ||
super().__init__(process_id, job, config) | ||
self.process_id = process_id | ||
self.job = job | ||
self.config = config | ||
|
||
def run(self): | ||
# here instead of init because child init needs to go first | ||
self.output_path = self.get_output_path() | ||
# implement in child class | ||
# be sure to call super().run() first | ||
pass | ||
|
||
# you can override this in the child class if you want | ||
# call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this | ||
def get_output_path(self, prefix=None, suffix=None): | ||
config_output_path = self.get_conf('output_path', None) | ||
config_filename = self.get_conf('filename', None) | ||
# replace [name] with name | ||
|
||
if config_output_path is not None: | ||
config_output_path = config_output_path.replace('[name]', self.job.name) | ||
return config_output_path | ||
|
||
if config_output_path is None and config_filename is not None: | ||
# build the output path from the output folder and filename | ||
return os.path.join(self.job.output_folder, config_filename) | ||
|
||
# build our own | ||
|
||
if suffix is None: | ||
# we will just add process it to the end of the filename if there is more than one process | ||
# and no other suffix was given | ||
suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else '' | ||
|
||
if prefix is None: | ||
prefix = '' | ||
|
||
output_filename = f"{prefix}{self.output_filename}{suffix}" | ||
|
||
return os.path.join(self.job.output_folder, output_filename) | ||
|
||
def save(self, state_dict): | ||
# prepare meta | ||
save_meta = get_meta_for_safetensors(self.meta, self.job.name) | ||
|
||
# save | ||
os.makedirs(os.path.dirname(self.output_path), exist_ok=True) | ||
|
||
# having issues with meta | ||
save_file(state_dict, self.output_path, save_meta) | ||
|
||
print(f"Saved to {self.output_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import copy | ||
import json | ||
from collections import OrderedDict | ||
|
||
from jobs import BaseJob | ||
|
||
|
||
class BaseProcess: | ||
meta: OrderedDict | ||
|
||
def __init__( | ||
self, | ||
process_id: int, | ||
job: BaseJob, | ||
config: OrderedDict | ||
): | ||
self.process_id = process_id | ||
self.job = job | ||
self.config = config | ||
self.meta = copy.deepcopy(self.job.meta) | ||
|
||
def get_conf(self, key, default=None, required=False, as_type=None): | ||
if key in self.config: | ||
value = self.config[key] | ||
if as_type is not None: | ||
value = as_type(value) | ||
return value | ||
elif required: | ||
raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key') | ||
else: | ||
if as_type is not None: | ||
return as_type(default) | ||
return default | ||
|
||
def run(self): | ||
# implement in child class | ||
# be sure to call super().run() first incase something is added here | ||
pass | ||
|
||
def add_meta(self, additional_meta: OrderedDict): | ||
self.meta.update(additional_meta) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from collections import OrderedDict | ||
from toolkit.lycoris_utils import extract_diff | ||
from .BaseExtractProcess import BaseExtractProcess | ||
from .. import ExtractJob | ||
|
||
mode_dict = { | ||
'fixed': { | ||
'linear': 64, | ||
'conv': 32, | ||
'type': int | ||
}, | ||
'threshold': { | ||
'linear': 0, | ||
'conv': 0, | ||
'type': float | ||
}, | ||
'ratio': { | ||
'linear': 0.5, | ||
'conv': 0.5, | ||
'type': float | ||
}, | ||
'quantile': { | ||
'linear': 0.5, | ||
'conv': 0.5, | ||
'type': float | ||
} | ||
} | ||
|
||
|
||
class LoconExtractProcess(BaseExtractProcess): | ||
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict): | ||
super().__init__(process_id, job, config) | ||
self.mode = self.get_conf('mode', 'fixed') | ||
self.use_sparse_bias = self.get_conf('use_sparse_bias', False) | ||
self.sparsity = self.get_conf('sparsity', 0.98) | ||
self.disable_cp = self.get_conf('disable_cp', False) | ||
|
||
# set modes | ||
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']: | ||
raise ValueError(f"Unknown mode: {self.mode}") | ||
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type']) | ||
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type']) | ||
|
||
def run(self): | ||
super().run() | ||
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}") | ||
|
||
state_dict, extract_diff_meta = extract_diff( | ||
self.job.base_model, | ||
self.job.extract_model, | ||
self.mode, | ||
self.linear_param, | ||
self.conv_param, | ||
self.job.device, | ||
self.use_sparse_bias, | ||
self.sparsity, | ||
not self.disable_cp | ||
) | ||
|
||
self.add_meta(extract_diff_meta) | ||
self.save(state_dict) | ||
|
||
def get_output_path(self, prefix=None, suffix=None): | ||
if suffix is None: | ||
suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}" | ||
return super().get_output_path(prefix, suffix) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .BaseExtractProcess import BaseExtractProcess | ||
from .LoconExtractProcess import LoconExtractProcess | ||
from .BaseProcess import BaseProcess |
Oops, something went wrong.