Skip to content

Commit

Permalink
added target weight to targets
Browse files Browse the repository at this point in the history
  • Loading branch information
jaretburkett committed Jul 23, 2023
1 parent 452f2a6 commit 9a28199
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
7 changes: 4 additions & 3 deletions config/examples/train_slider.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ config:
# negative is the prompt for the negative side of the slider and works the same as positive
# it does not necessarily work the same as a negative prompt when generating images
negative: "dog"
# LoRA weight to train this target. I recommend 1.0. Just leave it, it won't work
# how you expect if you change it
multiplier: 1.0
# the loss for this target is multiplied by this number.
# if you are doing more than one target it may be good to set less important ones
# to a lower number like 0.1 so they dont outweigh the primary target
weight: 1.0

# You can put any information you want here, and it will be saved in the model.
# The below is an example, but you can put your grocery list in it if you want.
Expand Down
18 changes: 13 additions & 5 deletions jobs/process/TrainSliderProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(self, **kwargs):
self.positive: str = kwargs.get('positive', None)
self.negative: str = kwargs.get('negative', None)
self.multiplier: float = kwargs.get('multiplier', 1.0)
self.weight: float = kwargs.get('weight', 1.0)


class SliderConfig:
Expand All @@ -137,6 +138,7 @@ def __init__(
height=512,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=1.0,
weight=1.0
):
self.target_class = target_class
self.positive = positive
Expand All @@ -146,6 +148,7 @@ def __init__(
self.height = height
self.action: int = action
self.multiplier = multiplier
self.weight = weight


class TrainSliderProcess(BaseTrainProcess):
Expand Down Expand Up @@ -429,7 +432,8 @@ def run(self):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier
multiplier=target.multiplier,
weight=target.weight
),
# erase inverted
EncodedPromptPair(
Expand All @@ -440,7 +444,8 @@ def run(self):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier * -1.0
multiplier=target.multiplier * -1.0,
weight=target.weight
),
# enhance standard, swap pos neg
EncodedPromptPair(
Expand All @@ -451,7 +456,8 @@ def run(self):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier
multiplier=target.multiplier,
weight=target.weight
),
# enhance inverted
EncodedPromptPair(
Expand All @@ -462,7 +468,8 @@ def run(self):
width=width,
height=height,
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier * -1.0
multiplier=target.multiplier * -1.0,
weight=target.weight
),
]

Expand Down Expand Up @@ -494,6 +501,7 @@ def run(self):
neutral = prompt_pair.neutral
negative = prompt_pair.negative
positive = prompt_pair.positive
weight = prompt_pair.weight

# set network multiplier
self.network.multiplier = prompt_pair.multiplier
Expand Down Expand Up @@ -621,7 +629,7 @@ def run(self):
loss = loss_function(
target_latents,
offset_neutral,
)
) * weight

loss_float = loss.item()
if self.train_config.optimizer.startswith('dadaptation'):
Expand Down

0 comments on commit 9a28199

Please sign in to comment.