Skip to content

Commit

Permalink
bugfix: Training
Browse files Browse the repository at this point in the history
  - Select correct channel for loss multiplier L2Reg
  • Loading branch information
torzdf committed Dec 5, 2021
1 parent 8b7b125 commit 3852b2b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions plugins/train/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,14 +1266,15 @@ def _set_loss_functions(self, output_names):
loss_func.add_loss(face_loss, mask_channel=mask_channels[0])
self._add_l2_regularization_term(loss_func, mask_channels[0])

mask_channel = 1
channel_idx = 1
for multiplier in ("eye_multiplier", "mouth_multiplier"):
mask_channel = mask_channels[channel_idx]
if self._config[multiplier] > 1:
loss_func.add_loss(face_loss,
weight=self._config[multiplier] * 1.0,
mask_channel=mask_channels[mask_channel])
mask_channel=mask_channel)
self._add_l2_regularization_term(loss_func, mask_channel)
mask_channel += 1
channel_idx += 1

logger.debug("%s: (output_name: '%s', function: %s)", name, output_name, loss_func)
self._funcs[output_name] = loss_func
Expand Down

0 comments on commit 3852b2b

Please sign in to comment.