Skip to content

Commit

Permalink
Setup base for training jobs. Added sd-scripts as a submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
jaretburkett committed Jul 8, 2023
1 parent 37354b0 commit 47d094e
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 18 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "repositories/sd-scripts"]
path = repositories/sd-scripts
url = https://github.com/kohya-ss/sd-scripts.git
File renamed without changes.
32 changes: 32 additions & 0 deletions config/examples/train.example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"job": "train",
"config": {
"name": "name_of_your_model",
"base_model": "/path/to/base/model",
"training_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"process": [
{
"type": "fine_tune"
}
]
},
"meta": {
"name": "[name]",
"description": "A short description of your model",
"trigger_words": [
"put",
"trigger",
"words",
"here"
],
"version": "0.1",
"creator": {
"name": "Your Name",
"email": "[email protected]",
"website": "https://yourwebsite.com"
},
"any": "All meta data above is arbitrary, it can be whatever you want."
}
}
23 changes: 23 additions & 0 deletions jobs/BaseJob.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from collections import OrderedDict
from typing import List

from jobs.process import BaseProcess


class BaseJob:
config: OrderedDict
job: str
name: str
meta: OrderedDict
process: List[BaseProcess]

def __init__(self, config: OrderedDict):
if not config:
Expand Down Expand Up @@ -37,6 +41,25 @@ def run(self):
# be sure to call super().run() first
pass

def load_processes(self, process_dict: dict):
# only call if you have processes in this job type
if 'process' not in self.config:
raise ValueError('config file is invalid. Missing "config.process" key')
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')

# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')

# check if dict key is process type
if process['type'] in process_dict:
self.process.append(process_dict[process['type']](i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')

def cleanup(self):
# if you implement this in child clas,
# be sure to call super().cleanup() LAST
Expand Down
24 changes: 8 additions & 16 deletions jobs/ExtractJob.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

from jobs.process import BaseExtractProcess

from jobs.process import ExtractLoconProcess

process_dict = {
'locon': ExtractLoconProcess,
}


class ExtractJob(BaseJob):
process: List[BaseExtractProcess]
Expand All @@ -19,21 +25,8 @@ def __init__(self, config: OrderedDict):
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')

if 'process' not in self.config:
raise ValueError('config file is invalid. Missing "config.process" key')
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')

# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
if 'type' not in process:
raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
if process['type'] == 'locon':
from jobs.process import LoconExtractProcess
self.process.append(LoconExtractProcess(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
# loads the processes from the config
self.load_processes(process_dict)

def run(self):
super().run()
Expand All @@ -50,4 +43,3 @@ def run(self):

for process in self.process:
process.run()

38 changes: 38 additions & 0 deletions jobs/TrainJob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List

from jobs.process import BaseExtractProcess, TrainFineTuneProcess

process_dict = {
'fine_tine': TrainFineTuneProcess
}


class TrainJob(BaseJob):
process: List[BaseExtractProcess]

def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.training_folder = self.get_conf('training_folder', required=True)
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')

# loads the processes from the config
self.load_processes(process_dict)

def run(self):
super().run()
# load models
print(f"Loading base model for training")
print(f" - Loading base model: {self.base_model_path}")
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)

print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

for process in self.process:
process.run()
1 change: 1 addition & 0 deletions jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
from .TrainJob import TrainJob
25 changes: 25 additions & 0 deletions jobs/process/BaseTrainProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process.BaseProcess import BaseProcess


class BaseTrainProcess(BaseProcess):
job: TrainJob
process_id: int
config: OrderedDict

def __init__(
self,
process_id: int,
job: TrainJob,
config: OrderedDict
):
super().__init__(process_id, job, config)
self.process_id = process_id
self.job = job
self.config = config

def run(self):
# implement in child class
# be sure to call super().run() first
pass
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
}


class LoconExtractProcess(BaseExtractProcess):
class ExtractLoconProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
Expand Down
13 changes: 13 additions & 0 deletions jobs/process/TrainFineTuneProcess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process import BaseTrainProcess


class TrainFineTuneProcess(BaseTrainProcess):
def __init__(self,process_id: int, job: TrainJob, config: OrderedDict):
super().__init__(process_id, job, config)

def run(self):
# implement in child class
# be sure to call super().run() first
pass
4 changes: 3 additions & 1 deletion jobs/process/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .BaseExtractProcess import BaseExtractProcess
from .LoconExtractProcess import LoconExtractProcess
from .ExtractLoconProcess import ExtractLoconProcess
from .BaseProcess import BaseProcess
from .BaseTrainProcess import BaseTrainProcess
from .TrainFineTuneProcess import TrainFineTuneProcess
1 change: 1 addition & 0 deletions repositories/sd-scripts
Submodule sd-scripts added at 0cfcb5
3 changes: 3 additions & 0 deletions toolkit/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@ def get_job(config_path) -> BaseJob:
if job == 'extract':
from jobs import ExtractJob
return ExtractJob(config)
elif job == 'train':
from jobs import TrainJob
return TrainJob(config)
else:
raise ValueError(f'Unknown job type {job}')

0 comments on commit 47d094e

Please sign in to comment.