diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml index 3d37eefc..34c12e5f 100644 --- a/config/examples/train_slider.example.yml +++ b/config/examples/train_slider.example.yml @@ -158,9 +158,10 @@ config: # negative is the prompt for the negative side of the slider and works the same as positive # it does not necessarily work the same as a negative prompt when generating images negative: "dog" - # LoRA weight to train this target. I recommend 1.0. Just leave it, it won't work - # how you expect if you change it - multiplier: 1.0 + # the loss for this target is multiplied by this number. + # if you are doing more than one target it may be good to set less important ones + # to a lower number like 0.1 so they dont outweigh the primary target + weight: 1.0 # You can put any information you want here, and it will be saved in the model. # The below is an example, but you can put your grocery list in it if you want. diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index c7f04cc1..3453d7d9 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -116,6 +116,7 @@ def __init__(self, **kwargs): self.positive: str = kwargs.get('positive', None) self.negative: str = kwargs.get('negative', None) self.multiplier: float = kwargs.get('multiplier', 1.0) + self.weight: float = kwargs.get('weight', 1.0) class SliderConfig: @@ -137,6 +138,7 @@ def __init__( height=512, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, multiplier=1.0, + weight=1.0 ): self.target_class = target_class self.positive = positive @@ -146,6 +148,7 @@ def __init__( self.height = height self.action: int = action self.multiplier = multiplier + self.weight = weight class TrainSliderProcess(BaseTrainProcess): @@ -429,7 +432,8 @@ def run(self): width=width, height=height, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=target.multiplier + multiplier=target.multiplier, + weight=target.weight ), # erase inverted EncodedPromptPair( @@ -440,7 +444,8 @@ def run(self): width=width, height=height, action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, - multiplier=target.multiplier * -1.0 + multiplier=target.multiplier * -1.0, + weight=target.weight ), # enhance standard, swap pos neg EncodedPromptPair( @@ -451,7 +456,8 @@ def run(self): width=width, height=height, action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - multiplier=target.multiplier + multiplier=target.multiplier, + weight=target.weight ), # enhance inverted EncodedPromptPair( @@ -462,7 +468,8 @@ def run(self): width=width, height=height, action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, - multiplier=target.multiplier * -1.0 + multiplier=target.multiplier * -1.0, + weight=target.weight ), ] @@ -494,6 +501,7 @@ def run(self): neutral = prompt_pair.neutral negative = prompt_pair.negative positive = prompt_pair.positive + weight = prompt_pair.weight # set network multiplier self.network.multiplier = prompt_pair.multiplier @@ -621,7 +629,7 @@ def run(self): loss = loss_function( target_latents, offset_neutral, - ) + ) * weight loss_float = loss.item() if self.train_config.optimizer.startswith('dadaptation'):