Skip to content

Commit

Permalink
Bisenet-FP - Improved weights for masker (deepfakes#1210)
Browse files Browse the repository at this point in the history
* Add option to load faceswap trained weights

* Switch to original weights if fs weights not present

* typofix
  • Loading branch information
torzdf authored Mar 28, 2022
1 parent 183aee3 commit 7f3e6bc
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 12 deletions.
61 changes: 51 additions & 10 deletions plugins/extract/mask/bisenet_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Architecture and Pre-Trained Model ported from PyTorch to Keras by TorzDF from
https://github.com/zllrunning/face-parsing.PyTorch
"""
import os

import numpy as np

Expand All @@ -13,28 +14,63 @@
UpSampling2D, ZeroPadding2D)

from lib.model.session import KSession
from plugins.extract._base import _get_config
from ._base import Masker, logger


class Mask(Masker):
""" Neural network to process face image into a segmentation mask of the face """
def __init__(self, **kwargs):
self._is_faceswap = self._check_weights_selection(kwargs.get("configfile"))

git_model_id = 14
model_filename = "bisnet_face_parsing_v1.h5"
model_filename = f"bisnet_face_parsing_v{'2' if self._is_faceswap else '1'}.h5"
super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)

self.name = "BiSeNet - Face Parsing"
self.input_size = 512
self.color_format = "RGB"
self.vram = 2304
self.vram_warnings = 256
self.vram_per_batch = 64
self.batchsize = self.config["batch-size"]
self._segment_indices = self._get_segment_indices()

self._segment_indices = self._get_segment_indices()
self._storage_centering = "head" if self.config["include_hair"] else "face"
# Separate storage for face and head masks
self._storage_name = f"{self._storage_name}_{self._storage_centering}"

def _check_weights_selection(self, configfile):
""" Check which weights have been selected.
This is required for passing along the correct file name for the corresponding weights
selection, so config needs to be loaded and scanned prior to parent loading it.
Parameters
----------
configfile: str
Path to a custom configuration ``ini`` file. ``None`` to use system configfile
Returns
-------
bool
``True`` if `faceswap` trained weights have been selected. ``False`` if `original`
weights have been selected
"""
config = _get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
retval = config.get("weights", "faceswap").lower() == "faceswap"

# TODO Remove this check when weights moved to main code.
if retval:
_chk_dir = os.listdir(os.path.join(os.path.dirname(__file__), ".cache"))
if 'bisnet_face_parsing_v2.h5' not in _chk_dir:
logger.warning("'Faceswap' trained weights are currently Patreon timed exclusive. "
"They will be coming to the main code soon.")
logger.warning("Switching to 'Original' weights.")
retval = False

return retval

def _get_segment_indices(self):
""" Obtain the segment indices to include within the face mask area based on user
configuration settings.
Expand All @@ -46,37 +82,42 @@ def _get_segment_indices(self):
Notes
-----
Model segment indices:
'original' Model segment indices:
0: background, 1: skin, 2: left brow, 3: right brow, 4: left eye, 5: right eye, 6: glasses
7: left ear, 8: right ear, 9: earing, 10: nose, 11: mouth, 12: upper lip, 13: lower_lip,
14: neck, 15: neck ?, 16: cloth, 17: hair, 18: hat
'faceswap' Model segment indices:
0: background, 1: skin, 2: ears, 3: hair, 4: glasses
"""
retval = [1, 2, 3, 4, 5, 10, 11, 12, 13]
retval = [1] if self._is_faceswap else [1, 2, 3, 4, 5, 10, 11, 12, 13]

if self.config["include_glasses"]:
retval.append(6)
retval.append(4 if self._is_faceswap else 6)
if self.config["include_ears"]:
retval.extend([7, 8, 9])
retval.extend([2] if self._is_faceswap else [7, 8, 9])
if self.config["include_hair"]:
retval.append(17)
retval.append(3 if self._is_faceswap else 17)
logger.debug("Selected segment indices: %s", retval)
return retval

def init_model(self):
""" Initialize the BiSeNet Face Parsing model. """
lbls = 5 if self._is_faceswap else 19
self.model = BiSeNet(self.model_path,
self.config["allow_growth"],
self._exclude_gpus,
self.input_size,
19)
lbls)

placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
dtype="float32")
self.model.predict(placeholder)

def process_input(self, batch):
""" Compile the detected faces for prediction """
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
mean = (0.384, 0.314, 0.279) if self._is_faceswap else (0.485, 0.456, 0.406)
std = (0.324, 0.286, 0.275) if self._is_faceswap else (0.229, 0.224, 0.225)

batch["feed"] = ((np.array([feed.face[..., :3]
for feed in batch["feed_faces"]],
Expand Down
18 changes: 16 additions & 2 deletions plugins/extract/mask/bisenet_fp_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@
group="settings",
gui_radio=False,
fixed=True),
"weights": dict(
default="faceswap",
info="The trained weights to use.\n"
"\n\tfaceswap - Weights trained on wildly varied Faceswap extracted data to better "
"handle varying conditions, obstructions, glasses and multiple targets within a "
"single extracted image."
"\n\toriginal - The original weights trained on the CelebAMask-HQ dataset.",
choices=["faceswap", "original"],
datatype=str,
group="settings",
gui_radio=True,
),
"include_ears": dict(
default=False,
info="Whether to include ears within the face mask.",
Expand All @@ -77,8 +89,10 @@
),
"include_glasses": dict(
default=True,
info="Whether to include glasses within the face mask. NB: excluding glasses will mask "
"out the lenses as well as the frames.",
info="Whether to include glasses within the face mask.\n\tFor 'original' weights "
"excluding glasses will mask out the lenses as well as the frames.\n\tFor 'faceswap' "
"weights, the model has been trained to mask out lenses if eyes cannot be seen (i.e. "
"dark sunglasses) or just the frames if the eyes can be seen. ",
datatype=bool,
group="settings"
),
Expand Down

0 comments on commit 7f3e6bc

Please sign in to comment.