Skip to content

Commit

Permalink
Add Mouth and Eye Priority to Loss options (deepfakes#1054)
Browse files Browse the repository at this point in the history
* Priority Training for Mouth and Eyes - Tensorflow

* Use chosen loss function for area multipliers

* loss multipliers for AMD

* Fix mask multipliers for plaid and roll PenalizedMaskLoss into LossWrapper

* losses_tf: roll PenalizedMaskLoss into LossWrapper
  • Loading branch information
torzdf authored Aug 27, 2020
1 parent 1363fa8 commit 3433928
Show file tree
Hide file tree
Showing 7 changed files with 521 additions and 263 deletions.
117 changes: 116 additions & 1 deletion lib/faces_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,50 @@ def add_mask(self, name, mask, affine_matrix, interpolator, storage_size=128):
fsmask.add(mask, affine_matrix, interpolator)
self.mask[name] = fsmask

def get_landmark_mask(self, size, area, aligned=True, dilation=0, blur_kernel=0, as_zip=False):
    """ Obtain a single channel mask based on the face's landmark points.

    Parameters
    ----------
    size: int or tuple
        The size of the mask to retrieve. Should be an `int` if an aligned face is being
        requested, or a ('height', 'width') shape tuple if a full frame is being requested
    area: ["mouth", "eyes"]
        The facial area to generate the mask for. `mouth` produces a single filled region,
        `eyes` produces one filled region per eye
    aligned: bool, optional
        ``True`` if the returned mask should be for an aligned face. ``False`` if a full frame
        mask should be returned. Default: ``True``
    dilation: int, optional
        The amount of dilation to apply to the mask. `0` for none. Default: `0`
    blur_kernel: int, optional
        The kernel size for applying gaussian blur to apply to the mask. `0` for none.
        Default: `0`
    as_zip: bool, optional
        ``True`` if the mask should be returned zipped otherwise ``False``

    Returns
    -------
    :class:`numpy.ndarray` or zipped array
        The mask as a single channel image of the given :attr:`size` dimension. If
        :attr:`as_zip` is ``True`` then the :class:`numpy.ndarray` will be contained within a
        zipped container

    Raises
    ------
    KeyError
        If :attr:`area` is not one of the supported areas
    """
    # TODO Face mask generation from landmarks
    logger.trace("size: %s, area: %s, aligned: %s, dilation: %s, blur_kernel: %s, as_zip: %s",
                 size, area, aligned, dilation, blur_kernel, as_zip)
    # Index ranges into the landmark array for each supported area.
    # NOTE(review): these ranges presumably follow the 68 point landmark layout - confirm
    areas = dict(mouth=[slice(48, 60)],
                 eyes=[slice(36, 42), slice(42, 48)])
    # Validate before touching the aligned face so that an invalid request fails fast with a
    # meaningful message, rather than after a potentially forced re-alignment
    if area not in areas:
        raise KeyError("Unsupported landmark mask area: '{}'. Select from: {}".format(
            area, list(areas)))
    if aligned and self.aligned.get("size") != size:
        # The cached aligned face is for a different size, so re-align at the requested size
        self.load_aligned(None, size=size, force=True)
    size = (size, size) if aligned else size
    landmarks = self.aligned_landmarks if aligned else self.landmarks_xy
    points = [landmarks[zone] for zone in areas[area]]
    mask = _LandmarksMask(size, points, dilation=dilation, blur_kernel=blur_kernel)
    retval = mask.get(as_zip=as_zip)
    return retval

def to_alignment(self):
""" Return the detected face formatted for an alignments file
Expand Down Expand Up @@ -511,6 +555,77 @@ def reference_interpolators(self):
return get_matrix_scaling(self.reference_matrix)


class _LandmarksMask():  # pylint:disable=too-few-public-methods
    """ Create a single channel mask from aligned landmark points.

    Parameters
    ----------
    size: tuple
        The (height, width) shape tuple that the mask should be returned as
    points: list
        A list of landmark points that correspond to the given shape tuple to create
        the mask. Each item in the list should be a :class:`numpy.ndarray` that a filled
        convex polygon will be created from
    dilation: int, optional
        The amount of dilation to apply to the mask. `0` for none. Default: `0`
    blur_kernel: int, optional
        The kernel size for applying gaussian blur to apply to the mask. `0` for none.
        Default: `0`
    """
    def __init__(self, size, points, dilation=0, blur_kernel=0):
        # The class name must be supplied for the first placeholder, otherwise the message has
        # more placeholders than arguments and logging reports a formatting error
        logger.trace("Initializing: %s: (size: %s, points: %s, dilation: %s, blur_kernel: %s)",
                     self.__class__.__name__, size, points, dilation, blur_kernel)
        self._size = size
        self._points = points
        self._dilation = dilation
        self._blur_kernel = blur_kernel
        self._mask = None  # Lazily populated on the first call to :func:`get`
        logger.trace("Initialized: %s", self.__class__.__name__)

    def get(self, as_zip=False):
        """ Obtain the mask, generating it on first request.

        Parameters
        ----------
        as_zip: bool, optional
            ``True`` if the mask should be returned zipped otherwise ``False``

        Returns
        -------
        :class:`numpy.ndarray` or zipped array
            The mask as a single channel image of the given :attr:`size` dimension. If
            :attr:`as_zip` is ``True`` then the :class:`numpy.ndarray` will be contained within a
            zipped container
        """
        # Test for "not yet generated" explicitly. Testing the mask's contents would cause a
        # legitimately empty (all zero) mask to be regenerated on every request
        if self._mask is None:
            self._generate_mask()
        retval = compress(self._mask) if as_zip else self._mask
        logger.trace("as_zip: %s, retval type: %s", as_zip, type(retval))
        return retval

    def _generate_mask(self):
        """ Generate the mask.

        Fills a convex hull for each set of landmark points, applies any requested dilation and
        blurring, scales the result to the 0-255 range and stores it as a `uint8` array in
        :attr:`_mask`. Nothing is returned.
        """
        mask = np.zeros((self._size) + (1, ), dtype="float32")
        for landmarks in self._points:
            # OpenCV polygon functions require integer co-ordinates
            lms = np.rint(landmarks).astype("int")
            cv2.fillConvexPoly(mask, cv2.convexHull(lms), 1.0, lineType=cv2.LINE_AA)
        if self._dilation != 0:
            # NOTE(review): cv2 may return a 2D array here (dropping the trailing channel axis)
            # - confirm downstream consumers accept the resulting shape
            mask = cv2.dilate(mask,
                              cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                                        (self._dilation, self._dilation)),
                              iterations=1)
        if self._blur_kernel != 0:
            mask = BlurMask("gaussian", mask, self._blur_kernel).blurred
        logger.trace("mask: (shape: %s, dtype: %s)", mask.shape, mask.dtype)
        self._mask = (mask * 255.0).astype("uint8")


class Mask():
""" Face Mask information and convenience methods
Expand Down Expand Up @@ -741,7 +856,7 @@ def _attr_name(dict_key):
return retval


class BlurMask():
class BlurMask(): # pylint:disable=too-few-public-methods
""" Factory class to return the correct blur object for requested blur type.
Works for square images only. Currently supports Gaussian and Normalized Box Filters.
Expand Down
164 changes: 68 additions & 96 deletions lib/model/losses_plaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __call__(self, y_true, y_pred):
denom = (K.square(u_true) + K.square(u_pred) + self.c_1) * (
var_pred + var_true + self.c_2)
ssim /= denom # no need for clipping, c_1 + c_2 make the denorm non-zero
return K.mean((1.0 - ssim) / 2.0)
return (1.0 - ssim) / 2.0

@staticmethod
def _preprocess_padding(padding):
Expand Down Expand Up @@ -199,66 +199,6 @@ def extract_image_patches(self, input_tensor, k_sizes, s_sizes,
return patches


class PenalizedLoss():  # pylint:disable=too-few-public-methods
    """ Loss function wrapper that weights a loss by a mask.

    Both the ground truth and the prediction are multiplied by the (propagation adjusted) mask
    before being handed to the wrapped loss function.

    Parameters
    ----------
    loss_func: function
        The loss function that is applied to the masked images
    mask_prop: float, optional
        The amount of mask propagation. At `1.0` the mask is used unchanged, at `0.0` the mask
        becomes all ones (no masking). Default: `1.0`
    """
    def __init__(self, loss_func, mask_prop=1.0):
        self._loss_func = loss_func
        self._mask_prop = mask_prop

    def __call__(self, y_true, y_pred):
        """ Calculate the wrapped loss on the masked ground truth and prediction.

        Parameters
        ----------
        y_true: tensor or variable
            The ground truth value. The final channel is taken as the mask and is removed from
            the image prior to calculating loss
        y_pred: tensor or variable
            The predicted value

        Returns
        -------
        tensor
            The Loss value
        """
        # Final channel holds the mask. Keep the channel axis so it broadcasts over the image
        weights = self._prepare_mask(K.expand_dims(y_true[..., -1], axis=-1))
        masked_true = y_true[..., :-1] * weights
        masked_pred = y_pred * weights
        if isinstance(self._loss_func, DSSIMObjective):
            # Extract Image Patches in SSIM requires that y_pred be of a known shape, so
            # specifically reshape the tensor.
            masked_pred = K.reshape(masked_pred, K.int_shape(y_pred))
        return self._loss_func(masked_true, masked_pred)

    def _prepare_mask(self, mask):
        """ Scale the mask by the propagation amount ready for applying to the loss inputs.

        Parameters
        ----------
        mask: tensor
            The masks for the current batch

        Returns
        -------
        tensor
            The mask scaled by the propagation amount and offset by its complement
        """
        floor = 1 - self._mask_prop
        return (mask * self._mask_prop) + floor


class GeneralizedLoss(): # pylint:disable=too-few-public-methods
""" Generalized function used to return a large variety of mathematical loss functions.
Expand Down Expand Up @@ -564,46 +504,33 @@ def _scharr_edges(cls, image, magnitude):

class LossWrapper(): # pylint:disable=too-few-public-methods
""" A wrapper class for multiple keras losses to enable multiple weighted loss functions on a
single output.
Parameters
----------
loss_functions: list
A list of either a tuple of (:class:`keras.losses.Loss`, scalar weight) or just a
:class:`keras.losses.Loss` function. If just the loss function is passed, then the weight
is assumed to be 1.0 """
def __init__(self, loss_functions):
logger.debug("Initializing: %s: (loss_functions: %s)",
self.__class__.__name__, loss_functions)
single output and masking.
"""
def __init__(self):
logger.debug("Initializing: %s", self.__class__.__name__)
self._loss_functions = []
self._loss_weights = []
self._compile_losses(loss_functions)
self._mask_channels = []
logger.debug("Initialized: %s", self.__class__.__name__)

def _compile_losses(self, loss_functions):
""" Splits the given loss_functions into the corresponding :attr:`_loss_functions' and
:attr:`_loss_weights' lists.
Loss functions are compiled into :class:`keras.compile_utils.LossesContainer` objects
def add_loss(self, function, weight=1.0, mask_channel=-1):
""" Add the given loss function with the given weight to the loss function chain.
Parameters
----------
loss_functions: list
A list of either a tuple of (:class:`keras.losses.Loss`, scalar weight) or just a
:class:`keras.losses.Loss` function. If just the loss function is passed, then the
weight is assumed to be 1.0 """
for loss_func in loss_functions:
if isinstance(loss_func, tuple):
assert len(loss_func) == 2, "Tuple loss functions should contain 2 items"
assert isinstance(loss_func[1], float), "weight should be a float"
func, weight = loss_func
else:
func = loss_func
weight = 1.0
self._loss_functions.append(func)
self._loss_weights.append(weight)
logger.debug("Compiled losses: (functions: %s, weights: %s",
self._loss_functions, self._loss_weights)
function: :class:`keras.losses.Loss`
The loss function to add to the loss chain
weight: float, optional
The weighting to apply to the loss function. Default: `1.0`
mask_channel: int, optional
The channel in the `y_true` image that the mask exists in. Set to `-1` if there is no
mask for the given loss function. Default: `-1`
"""
logger.debug("Adding loss: (function: %s, weight: %s, mask_channel: %s)",
function, weight, mask_channel)
self._loss_functions.append(function)
self._loss_weights.append(weight)
self._mask_channels.append(mask_channel)

def __call__(self, y_true, y_pred):
""" Call the sub loss functions for the loss wrapper.
Expand All @@ -623,6 +550,51 @@ def __call__(self, y_true, y_pred):
The final loss value
"""
loss = 0.0
for func, weight in zip(self._loss_functions, self._loss_weights):
loss += (K.mean(func(y_true, y_pred)) * weight)
for func, weight, mask_channel in zip(self._loss_functions,
self._loss_weights,
self._mask_channels):
logger.debug("Processing loss function: (func: %s, weight: %s, mask_channel: %s)",
func, weight, mask_channel)
n_true, n_pred = self._apply_mask(y_true, y_pred, mask_channel)
if isinstance(func, DSSIMObjective):
# Extract Image Patches in SSIM requires that y_pred be of a known shape, so
# specifically reshape the tensor.
n_pred = K.reshape(n_pred, K.int_shape(y_pred))
this_loss = func(n_true, n_pred)
loss_dims = K.ndim(this_loss)
loss += (K.mean(this_loss, axis=list(range(1, loss_dims))) * weight)
return loss

@classmethod
def _apply_mask(cls, y_true, y_pred, mask_channel, mask_prop=1.0):
    """ Multiply the ground truth and prediction by the mask held in the requested channel of
    y_true. When no mask is requested the first 3 channels of each input are returned as is.

    Parameters
    ----------
    y_true: tensor or variable
        The ground truth value
    y_pred: tensor or variable
        The predicted value
    mask_channel: int
        The channel within y_true that the required mask resides in. `-1` for no mask
    mask_prop: float, optional
        The amount of mask propagation. Default: `1.0`

    Returns
    -------
    tuple
        (n_true, n_pred): The ground truth and predicted value tensors with the mask applied
    """
    if mask_channel != -1:
        logger.debug("Applying mask from channel %s", mask_channel)
        # Keep the channel axis on the mask so that it broadcasts across the image channels
        weights = K.expand_dims(y_true[..., mask_channel], axis=-1)
        inverse_prop = 1 - mask_prop
        weights = (weights * mask_prop) + inverse_prop
        return y_true[..., :3] * weights, y_pred * weights

    logger.debug("No mask to apply")
    return y_true[..., :3], y_pred[..., :3]
Loading

0 comments on commit 3433928

Please sign in to comment.