forked from ostris/ai-toolkit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
47d094e
commit 57f14e5
Showing
16 changed files
with
1,031 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,85 @@ | ||
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint | ||
from .BaseJob import BaseJob | ||
from collections import OrderedDict | ||
from typing import List | ||
|
||
from jobs.process import BaseExtractProcess, TrainFineTuneProcess | ||
|
||
process_dict = { | ||
'fine_tine': TrainFineTuneProcess | ||
} | ||
|
||
|
||
class TrainJob(BaseJob): | ||
process: List[BaseExtractProcess] | ||
|
||
def __init__(self, config: OrderedDict): | ||
super().__init__(config) | ||
self.base_model_path = self.get_conf('base_model', required=True) | ||
self.base_model = None | ||
self.training_folder = self.get_conf('training_folder', required=True) | ||
self.is_v2 = self.get_conf('is_v2', False) | ||
self.device = self.get_conf('device', 'cpu') | ||
|
||
# loads the processes from the config | ||
self.load_processes(process_dict) | ||
|
||
def run(self): | ||
super().run() | ||
# load models | ||
print(f"Loading base model for training") | ||
print(f" - Loading base model: {self.base_model_path}") | ||
self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) | ||
|
||
print("") | ||
print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") | ||
|
||
for process in self.process: | ||
process.run() | ||
# from jobs import BaseJob | ||
# from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint | ||
# from collections import OrderedDict | ||
# from typing import List | ||
# from jobs.process import BaseExtractProcess, TrainFineTuneProcess | ||
# import gc | ||
# import time | ||
# import argparse | ||
# import itertools | ||
# import math | ||
# import os | ||
# from multiprocessing import Value | ||
# | ||
# from tqdm import tqdm | ||
# import torch | ||
# from accelerate.utils import set_seed | ||
# from accelerate import Accelerator | ||
# import diffusers | ||
# from diffusers import DDPMScheduler | ||
# | ||
# from toolkit.paths import SD_SCRIPTS_ROOT | ||
# | ||
# import sys | ||
# | ||
# sys.path.append(SD_SCRIPTS_ROOT) | ||
# | ||
# import library.train_util as train_util | ||
# import library.config_util as config_util | ||
# from library.config_util import ( | ||
# ConfigSanitizer, | ||
# BlueprintGenerator, | ||
# ) | ||
# import toolkit.train_tools as train_tools | ||
# import library.custom_train_functions as custom_train_functions | ||
# from library.custom_train_functions import ( | ||
# apply_snr_weight, | ||
# get_weighted_text_embeddings, | ||
# prepare_scheduler_for_custom_training, | ||
# pyramid_noise_like, | ||
# apply_noise_offset, | ||
# scale_v_prediction_loss_like_noise_prediction, | ||
# ) | ||
# | ||
# process_dict = { | ||
# 'fine_tine': 'TrainFineTuneProcess' | ||
# } | ||
# | ||
# | ||
# class TrainJob(BaseJob): | ||
# process: List[BaseExtractProcess] | ||
# | ||
# def __init__(self, config: OrderedDict): | ||
# super().__init__(config) | ||
# self.base_model_path = self.get_conf('base_model', required=True) | ||
# self.base_model = None | ||
# self.training_folder = self.get_conf('training_folder', required=True) | ||
# self.is_v2 = self.get_conf('is_v2', False) | ||
# self.device = self.get_conf('device', 'cpu') | ||
# self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) | ||
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 | ||
# self.logging_dir = self.get_conf('logging_dir', None) | ||
# | ||
# # loads the processes from the config | ||
# self.load_processes(process_dict) | ||
# | ||
# # setup accelerator | ||
# self.accelerator = Accelerator( | ||
# gradient_accumulation_steps=self.gradient_accumulation_steps, | ||
# mixed_precision=self.mixed_precision, | ||
# log_with=None if self.logging_dir is None else 'tensorboard', | ||
# logging_dir=self.logging_dir, | ||
# ) | ||
# | ||
# def run(self): | ||
# super().run() | ||
# # load models | ||
# print(f"Loading base model for training") | ||
# print(f" - Loading base model: {self.base_model_path}") | ||
# self.base_model = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path) | ||
# | ||
# print("") | ||
# print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}") | ||
# | ||
# for process in self.process: | ||
# process.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
from .BaseJob import BaseJob | ||
from .ExtractJob import ExtractJob | ||
from .TrainJob import TrainJob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ safetensors | |
diffusers | ||
transformers | ||
lycoris_lora | ||
flatten_json | ||
flatten_json | ||
accelerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.