forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Setup base for training jobs. Added sd-scripts as a submodule
- Loading branch information
1 parent
37354b0
commit 47d094e
Showing
13 changed files
with
151 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "repositories/sd-scripts"] | ||
path = repositories/sd-scripts | ||
url = https://github.com/kohya-ss/sd-scripts.git |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
{ | ||
"job": "train", | ||
"config": { | ||
"name": "name_of_your_model", | ||
"base_model": "/path/to/base/model", | ||
"training_folder": "/path/to/output/folder", | ||
"is_v2": false, | ||
"device": "cpu", | ||
"process": [ | ||
{ | ||
"type": "fine_tune" | ||
} | ||
] | ||
}, | ||
"meta": { | ||
"name": "[name]", | ||
"description": "A short description of your model", | ||
"trigger_words": [ | ||
"put", | ||
"trigger", | ||
"words", | ||
"here" | ||
], | ||
"version": "0.1", | ||
"creator": { | ||
"name": "Your Name", | ||
"email": "[email protected]", | ||
"website": "https://yourwebsite.com" | ||
}, | ||
"any": "All meta data above is arbitrary, it can be whatever you want." | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint | ||
from .BaseJob import BaseJob | ||
from collections import OrderedDict | ||
from typing import List | ||
|
||
from jobs.process import BaseExtractProcess, TrainFineTuneProcess | ||
|
||
process_dict = { | ||
'fine_tine': TrainFineTuneProcess | ||
} | ||
|
||
|
||
class TrainJob(BaseJob): | ||
process: List[BaseExtractProcess] | ||
|
||
def __init__(self, config: OrderedDict): | ||
super().__init__(config) | ||
self.base_model_path = self.get_conf('base_model', required=True) | ||
self.base_model = None | ||
self.training_folder = self.get_conf('training_folder', required=True) | ||
self.is_v2 = self.get_conf('is_v2', False) | ||
self.device = self.get_conf('device', 'cpu') | ||
|
||
# loads the processes from the config | ||
self.load_processes(process_dict) | ||
|
||
def run(self): | ||
super().run() | ||
# load models | ||
print(f"Loading base model for training") | ||
print(f" - Loading base model: {self.base_model_path}") | ||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) | ||
|
||
print("") | ||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") | ||
|
||
for process in self.process: | ||
process.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .BaseJob import BaseJob | ||
from .ExtractJob import ExtractJob | ||
from .TrainJob import TrainJob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from collections import OrderedDict | ||
from jobs import TrainJob | ||
from jobs.process.BaseProcess import BaseProcess | ||
|
||
|
||
class BaseTrainProcess(BaseProcess): | ||
job: TrainJob | ||
process_id: int | ||
config: OrderedDict | ||
|
||
def __init__( | ||
self, | ||
process_id: int, | ||
job: TrainJob, | ||
config: OrderedDict | ||
): | ||
super().__init__(process_id, job, config) | ||
self.process_id = process_id | ||
self.job = job | ||
self.config = config | ||
|
||
def run(self): | ||
# implement in child class | ||
# be sure to call super().run() first | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from collections import OrderedDict | ||
from jobs import TrainJob | ||
from jobs.process import BaseTrainProcess | ||
|
||
|
||
class TrainFineTuneProcess(BaseTrainProcess): | ||
def __init__(self,process_id: int, job: TrainJob, config: OrderedDict): | ||
super().__init__(process_id, job, config) | ||
|
||
def run(self): | ||
# implement in child class | ||
# be sure to call super().run() first | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from .BaseExtractProcess import BaseExtractProcess | ||
from .LoconExtractProcess import LoconExtractProcess | ||
from .ExtractLoconProcess import ExtractLoconProcess | ||
from .BaseProcess import BaseProcess | ||
from .BaseTrainProcess import BaseTrainProcess | ||
from .TrainFineTuneProcess import TrainFineTuneProcess |
Submodule sd-scripts
added at
0cfcb5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters