Skip to content

Commit

Permalink
Reworked so everything is in classes for easy expansion. Single entry…
Browse files Browse the repository at this point in the history
… point for all config files now.
  • Loading branch information
jaretburkett committed Jul 8, 2023
1 parent 27df03a commit 37354b0
Show file tree
Hide file tree
Showing 16 changed files with 424 additions and 189 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ Just copy that file, into the `config` folder, and rename it to `whatever_you_wa
Then you can edit the file to your liking. and call it like so:

```bash
python3 scripts/extract_locon.py "whatever_you_want"
python3 run.py "whatever_you_want"
```

You can also put a full path to a config file, if you want to keep it somewhere else.

```bash
python3 scripts/extract_locon.py "/home/user/whatever_you_want.json"
python3 run.py "/home/user/whatever_you_want.json"
```

File name is auto generated and dumped into the `output` folder. You can put whatever meta you want in the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
{
"job": "extract",
"config": {
"name": "name_of_your_model",
"base_model": "/path/to/base/model",
"extract_model": "/path/to/model/to/extract",
"output_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"use_sparse_bias": false,
"sparsity": 0.98,
"disable_cp": false,
"process": [
{
"filename":"[name]_64_32.safetensors",
"type": "locon",
"mode": "fixed",
"linear_dim": 64,
"conv_dim": 32
"linear": 64,
"conv": 32
},
{
"output_path": "/absolute/path/for/this/output.safetensors",
"type": "locon",
"mode": "ratio",
"linear_ratio": 0.2,
"conv_ratio": 0.2
"linear": 0.2,
"conv": 0.2
},
{
"type": "locon",
"mode": "quantile",
"linear_quantile": 0.5,
"conv_quantile": 0.5
"linear": 0.5,
"conv": 0.5
}
]
},
Expand All @@ -41,6 +44,7 @@
"name": "Your Name",
"email": "[email protected]",
"website": "https://yourwebsite.com"
}
},
"any": "All meta data above is arbitrary, it can be whatever you want."
}
}
8 changes: 8 additions & 0 deletions info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from collections import OrderedDict

v = OrderedDict()
v["name"] = "ai-toolkit"
v["repo"] = "https://github.com/ostris/ai-toolkit"
v["version"] = "0.0.1"

software_meta = v
43 changes: 43 additions & 0 deletions jobs/BaseJob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from collections import OrderedDict


class BaseJob:
config: OrderedDict
job: str
name: str
meta: OrderedDict

def __init__(self, config: OrderedDict):
if not config:
raise ValueError('config is required')

self.config = config['config']
self.job = config['job']
self.name = self.get_conf('name', required=True)
if 'meta' in config:
self.meta = config['meta']
else:
self.meta = OrderedDict()

def get_conf(self, key, default=None, required=False):
if key in self.config:
return self.config[key]
elif required:
raise ValueError(f'config file error. Missing "config.{key}" key')
else:
return default

def run(self):
print("")
print(f"#############################################")
print(f"# Running job: {self.name}")
print(f"#############################################")
print("")
# implement in child class
# be sure to call super().run() first
pass

def cleanup(self):
# if you implement this in child clas,
# be sure to call super().cleanup() LAST
del self
53 changes: 53 additions & 0 deletions jobs/ExtractJob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List

from jobs.process import BaseExtractProcess


class ExtractJob(BaseJob):
process: List[BaseExtractProcess]

def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.extract_model_path = self.get_conf('extract_model', required=True)
self.extract_model = None
self.output_folder = self.get_conf('output_folder', required=True)
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')

if 'process' not in self.config:
raise ValueError('config file is invalid. Missing "config.process" key')
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')

# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
if process['type'] == 'locon':
from jobs.process import LoconExtractProcess
self.process.append(LoconExtractProcess(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')

def run(self):
super().run()
# load models
print(f"Loading models for extraction")
print(f" - Loading base model: {self.base_model_path}")
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)

print(f" - Loading extract model: {self.extract_model_path}")
self.extract_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)

print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

for process in self.process:
process.run()

2 changes: 2 additions & 0 deletions jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
76 changes: 76 additions & 0 deletions jobs/process/BaseExtractProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
from collections import OrderedDict

from safetensors.torch import save_file

from jobs import ExtractJob
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors


class BaseExtractProcess(BaseProcess):
job: ExtractJob
process_id: int
config: OrderedDict
output_folder: str
output_filename: str
output_path: str

def __init__(
self,
process_id: int,
job: ExtractJob,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.process_id = process_id
self.job = job
self.config = config

def run(self):
# here instead of init because child init needs to go first
self.output_path = self.get_output_path()
# implement in child class
# be sure to call super().run() first
pass

# you can override this in the child class if you want
# call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
def get_output_path(self, prefix=None, suffix=None):
config_output_path = self.get_conf('output_path', None)
config_filename = self.get_conf('filename', None)
# replace [name] with name

if config_output_path is not None:
config_output_path = config_output_path.replace('[name]', self.job.name)
return config_output_path

if config_output_path is None and config_filename is not None:
# build the output path from the output folder and filename
return os.path.join(self.job.output_folder, config_filename)

# build our own

if suffix is None:
# we will just add process it to the end of the filename if there is more than one process
# and no other suffix was given
suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''

if prefix is None:
prefix = ''

output_filename = f"{prefix}{self.output_filename}{suffix}"

return os.path.join(self.job.output_folder, output_filename)

def save(self, state_dict):
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)

# save
os.makedirs(os.path.dirname(self.output_path), exist_ok=True)

# having issues with meta
save_file(state_dict, self.output_path, save_meta)

print(f"Saved to {self.output_path}")
42 changes: 42 additions & 0 deletions jobs/process/BaseProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import copy
import json
from collections import OrderedDict

from jobs import BaseJob


class BaseProcess:
meta: OrderedDict

def __init__(
self,
process_id: int,
job: BaseJob,
config: OrderedDict
):
self.process_id = process_id
self.job = job
self.config = config
self.meta = copy.deepcopy(self.job.meta)

def get_conf(self, key, default=None, required=False, as_type=None):
if key in self.config:
value = self.config[key]
if as_type is not None:
value = as_type(value)
return value
elif required:
raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
else:
if as_type is not None:
return as_type(default)
return default

def run(self):
# implement in child class
# be sure to call super().run() first incase something is added here
pass

def add_meta(self, additional_meta: OrderedDict):
self.meta.update(additional_meta)

67 changes: 67 additions & 0 deletions jobs/process/LoconExtractProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
from .. import ExtractJob

mode_dict = {
'fixed': {
'linear': 64,
'conv': 32,
'type': int
},
'threshold': {
'linear': 0,
'conv': 0,
'type': float
},
'ratio': {
'linear': 0.5,
'conv': 0.5,
'type': float
},
'quantile': {
'linear': 0.5,
'conv': 0.5,
'type': float
}
}


class LoconExtractProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
self.sparsity = self.get_conf('sparsity', 0.98)
self.disable_cp = self.get_conf('disable_cp', False)

# set modes
if self.mode not in ['fixed', 'threshold', 'ratio', 'quantile']:
raise ValueError(f"Unknown mode: {self.mode}")
self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], mode_dict[self.mode]['type'])
self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], mode_dict[self.mode]['type'])

def run(self):
super().run()
print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")

state_dict, extract_diff_meta = extract_diff(
self.job.base_model,
self.job.extract_model,
self.mode,
self.linear_param,
self.conv_param,
self.job.device,
self.use_sparse_bias,
self.sparsity,
not self.disable_cp
)

self.add_meta(extract_diff_meta)
self.save(state_dict)

def get_output_path(self, prefix=None, suffix=None):
if suffix is None:
suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
return super().get_output_path(prefix, suffix)

3 changes: 3 additions & 0 deletions jobs/process/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess
from .BaseProcess import BaseProcess
Loading

0 comments on commit 37354b0

Please sign in to comment.