Skip to content

Commit

Permalink
WIP implementing training
Browse files Browse the repository at this point in the history
  • Loading branch information
jaretburkett committed Jul 12, 2023
1 parent 47d094e commit 57f14e5
Show file tree
Hide file tree
Showing 16 changed files with 1,031 additions and 67 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,25 @@ a general understanding of python, pip, pytorch, and using virtual environments:
Linux:

```bash
git submodule update --init --recursive
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
cd requirements/sd-scripts
pip install --no-deps -e .
cd ../..
```

Windows:

```bash
git submodule update --init --recursive
python3 -m venv venv
venv\Scripts\activate
pip install -r requirements.txt
cd requirements/sd-scripts
pip install --no-deps -e .
cd ../..
```

## Current Tools
Expand Down
6 changes: 5 additions & 1 deletion config/examples/train.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
"base_model": "/path/to/base/model",
"training_folder": "/path/to/output/folder",
"is_v2": false,
"device": "cpu",
"device": "cuda",
"gradient_accumulation_steps": 1,
"mixed_precision": "fp16",
"logging_dir": "/path/to/tensorboard/log/folder",

"process": [
{
"type": "fine_tune"
Expand Down
6 changes: 5 additions & 1 deletion jobs/BaseJob.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
from collections import OrderedDict
from typing import List

Expand Down Expand Up @@ -48,6 +49,8 @@ def load_processes(self, process_dict: dict):
if len(self.config['process']) == 0:
raise ValueError('config file is invalid. "config.process" must be a list of processes')

module = importlib.import_module('jobs.process')

# add the processes
self.process = []
for i, process in enumerate(self.config['process']):
Expand All @@ -56,7 +59,8 @@ def load_processes(self, process_dict: dict):

# check if dict key is process type
if process['type'] in process_dict:
self.process.append(process_dict[process['type']](i, self, process))
ProcessClass = getattr(module, process_dict[process['type']])
self.process.append(ProcessClass(i, self, process))
else:
raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')

Expand Down
7 changes: 2 additions & 5 deletions jobs/ExtractJob.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List

from jobs.process import BaseExtractProcess

from jobs.process import ExtractLoconProcess
from jobs import BaseJob

process_dict = {
'locon': ExtractLoconProcess,
'locon': 'ExtractLoconProcess',
}


class ExtractJob(BaseJob):
process: List[BaseExtractProcess]

def __init__(self, config: OrderedDict):
super().__init__(config)
Expand Down
123 changes: 85 additions & 38 deletions jobs/TrainJob.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,85 @@
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from .BaseJob import BaseJob
from collections import OrderedDict
from typing import List

from jobs.process import BaseExtractProcess, TrainFineTuneProcess

process_dict = {
'fine_tine': TrainFineTuneProcess
}


class TrainJob(BaseJob):
process: List[BaseExtractProcess]

def __init__(self, config: OrderedDict):
super().__init__(config)
self.base_model_path = self.get_conf('base_model', required=True)
self.base_model = None
self.training_folder = self.get_conf('training_folder', required=True)
self.is_v2 = self.get_conf('is_v2', False)
self.device = self.get_conf('device', 'cpu')

# loads the processes from the config
self.load_processes(process_dict)

def run(self):
super().run()
# load models
print(f"Loading base model for training")
print(f" - Loading base model: {self.base_model_path}")
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)

print("")
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")

for process in self.process:
process.run()
# from jobs import BaseJob
# from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
# from collections import OrderedDict
# from typing import List
# from jobs.process import BaseExtractProcess, TrainFineTuneProcess
# import gc
# import time
# import argparse
# import itertools
# import math
# import os
# from multiprocessing import Value
#
# from tqdm import tqdm
# import torch
# from accelerate.utils import set_seed
# from accelerate import Accelerator
# import diffusers
# from diffusers import DDPMScheduler
#
# from toolkit.paths import SD_SCRIPTS_ROOT
#
# import sys
#
# sys.path.append(SD_SCRIPTS_ROOT)
#
# import library.train_util as train_util
# import library.config_util as config_util
# from library.config_util import (
# ConfigSanitizer,
# BlueprintGenerator,
# )
# import toolkit.train_tools as train_tools
# import library.custom_train_functions as custom_train_functions
# from library.custom_train_functions import (
# apply_snr_weight,
# get_weighted_text_embeddings,
# prepare_scheduler_for_custom_training,
# pyramid_noise_like,
# apply_noise_offset,
# scale_v_prediction_loss_like_noise_prediction,
# )
#
# process_dict = {
# 'fine_tune': 'TrainFineTuneProcess'
# }
#
#
# class TrainJob(BaseJob):
# process: List[BaseExtractProcess]
#
# def __init__(self, config: OrderedDict):
# super().__init__(config)
# self.base_model_path = self.get_conf('base_model', required=True)
# self.base_model = None
# self.training_folder = self.get_conf('training_folder', required=True)
# self.is_v2 = self.get_conf('is_v2', False)
# self.device = self.get_conf('device', 'cpu')
# self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
# self.logging_dir = self.get_conf('logging_dir', None)
#
# # loads the processes from the config
# self.load_processes(process_dict)
#
# # setup accelerator
# self.accelerator = Accelerator(
# gradient_accumulation_steps=self.gradient_accumulation_steps,
# mixed_precision=self.mixed_precision,
# log_with=None if self.logging_dir is None else 'tensorboard',
# logging_dir=self.logging_dir,
# )
#
# def run(self):
# super().run()
# # load models
# print(f"Loading base model for training")
# print(f" - Loading base model: {self.base_model_path}")
# self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
#
# print("")
# print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
#
# for process in self.process:
# process.run()
1 change: 0 additions & 1 deletion jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .BaseJob import BaseJob
from .ExtractJob import ExtractJob
from .TrainJob import TrainJob
6 changes: 3 additions & 3 deletions jobs/process/BaseExtractProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from safetensors.torch import save_file

from jobs import ExtractJob
from jobs.process.BaseProcess import BaseProcess
from toolkit.metadata import get_meta_for_safetensors

from typing import ForwardRef


class BaseExtractProcess(BaseProcess):
job: ExtractJob
process_id: int
config: OrderedDict
output_folder: str
Expand All @@ -19,7 +19,7 @@ class BaseExtractProcess(BaseProcess):
def __init__(
self,
process_id: int,
job: ExtractJob,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)
Expand Down
7 changes: 4 additions & 3 deletions jobs/process/BaseProcess.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import copy
import json
from collections import OrderedDict

from jobs import BaseJob
from typing import ForwardRef


class BaseProcess:
Expand All @@ -11,7 +10,7 @@ class BaseProcess:
def __init__(
self,
process_id: int,
job: BaseJob,
job: 'BaseJob',
config: OrderedDict
):
self.process_id = process_id
Expand Down Expand Up @@ -40,3 +39,5 @@ def run(self):
def add_meta(self, additional_meta: OrderedDict):
self.meta.update(additional_meta)


from jobs import BaseJob
4 changes: 1 addition & 3 deletions jobs/process/BaseTrainProcess.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from collections import OrderedDict
from jobs import TrainJob
from jobs.process.BaseProcess import BaseProcess


class BaseTrainProcess(BaseProcess):
job: TrainJob
process_id: int
config: OrderedDict

def __init__(
self,
process_id: int,
job: TrainJob,
job,
config: OrderedDict
):
super().__init__(process_id, job, config)
Expand Down
3 changes: 1 addition & 2 deletions jobs/process/ExtractLoconProcess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import OrderedDict
from toolkit.lycoris_utils import extract_diff
from .BaseExtractProcess import BaseExtractProcess
from .. import ExtractJob

mode_dict = {
'fixed': {
Expand All @@ -28,7 +27,7 @@


class ExtractLoconProcess(BaseExtractProcess):
def __init__(self, process_id: int, job: ExtractJob, config: OrderedDict):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
self.mode = self.get_conf('mode', 'fixed')
self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ safetensors
diffusers
transformers
lycoris_lora
flatten_json
flatten_json
accelerate
6 changes: 2 additions & 4 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import os
import sys
from collections import OrderedDict

from jobs import BaseJob

sys.path.insert(0, os.getcwd())
import argparse
from toolkit.job import get_job
Expand Down Expand Up @@ -49,6 +45,8 @@ def main():
jobs_completed = 0
jobs_failed = 0

print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")

for config_file in config_file_list:
try:
job = get_job(config_file)
Expand Down
Loading

0 comments on commit 57f14e5

Please sign in to comment.