From 7f3e6bc3cfc09a7abfac8070bf9ca7f5ffc40604 Mon Sep 17 00:00:00 2001
From: torzdf <36920800+torzdf@users.noreply.github.com>
Date: Mon, 28 Mar 2022 12:58:43 +0100
Subject: [PATCH] Bisenet-FP - Improved weights for masker (#1210)

* Add option to load faceswap trained weights

* Switch to original weights if fs weights not present

* typofix
---
 plugins/extract/mask/bisenet_fp.py          | 61 +++++++++++++++++----
 plugins/extract/mask/bisenet_fp_defaults.py | 18 +++++-
 2 files changed, 67 insertions(+), 12 deletions(-)

diff --git a/plugins/extract/mask/bisenet_fp.py b/plugins/extract/mask/bisenet_fp.py
index 0ada9fe68d..8341e42be7 100644
--- a/plugins/extract/mask/bisenet_fp.py
+++ b/plugins/extract/mask/bisenet_fp.py
@@ -4,6 +4,7 @@
 Architecture and Pre-Trained Model ported from PyTorch to Keras by TorzDF from
 https://github.com/zllrunning/face-parsing.PyTorch
 """
+import os
 
 import numpy as np
 
@@ -13,15 +14,19 @@
                           UpSampling2D, ZeroPadding2D)
 
 from lib.model.session import KSession
+from plugins.extract._base import _get_config
 from ._base import Masker, logger
 
 
 class Mask(Masker):
     """ Neural network to process face image into a segmentation mask of the face """
     def __init__(self, **kwargs):
+        self._is_faceswap = self._check_weights_selection(kwargs.get("configfile"))
+
         git_model_id = 14
-        model_filename = "bisnet_face_parsing_v1.h5"
+        model_filename = f"bisnet_face_parsing_v{'2' if self._is_faceswap else '1'}.h5"
         super().__init__(git_model_id=git_model_id, model_filename=model_filename, **kwargs)
+
         self.name = "BiSeNet - Face Parsing"
         self.input_size = 512
         self.color_format = "RGB"
@@ -29,12 +34,43 @@ def __init__(self, **kwargs):
         self.vram_warnings = 256
         self.vram_per_batch = 64
         self.batchsize = self.config["batch-size"]
-        self._segment_indices = self._get_segment_indices()
 
+        self._segment_indices = self._get_segment_indices()
         self._storage_centering = "head" if self.config["include_hair"] else "face"
         # Separate storage for face and head masks
         self._storage_name = f"{self._storage_name}_{self._storage_centering}"
 
+    def _check_weights_selection(self, configfile):
+        """ Check which weights have been selected.
+
+        This is required for passing along the correct file name for the corresponding weights
+        selection, so config needs to be loaded and scanned prior to parent loading it.
+
+        Parameters
+        ----------
+        configfile: str
+            Path to a custom configuration ``ini`` file. ``None`` to use system configfile
+
+        Returns
+        -------
+        bool
+            ``True`` if `faceswap` trained weights have been selected. ``False`` if `original`
+            weights have been selected
+        """
+        config = _get_config(".".join(self.__module__.split(".")[-2:]), configfile=configfile)
+        retval = config.get("weights", "faceswap").lower() == "faceswap"
+
+        # TODO Remove this check when weights moved to main code.
+        if retval:
+            _chk_dir = os.listdir(os.path.join(os.path.dirname(__file__), ".cache"))
+            if 'bisnet_face_parsing_v2.h5' not in _chk_dir:
+                logger.warning("'Faceswap' trained weights are currently Patreon timed exclusive. "
+                               "They will be coming to the main code soon.")
+                logger.warning("Switching to 'Original' weights.")
+                retval = False
+
+        return retval
+
     def _get_segment_indices(self):
         """ Obtain the segment indices to include within the face mask area based on user
         configuration settings.
@@ -46,28 +82,33 @@ def _get_segment_indices(self):
 
         Notes
         -----
-        Model segment indices:
+        'original' Model segment indices:
         0: background, 1: skin, 2: left brow, 3: right brow, 4: left eye, 5: right eye, 6: glasses
         7: left ear, 8: right ear, 9: earing, 10: nose, 11: mouth, 12: upper lip, 13: lower_lip,
         14: neck, 15: neck ?, 16: cloth, 17: hair, 18: hat
+
+        'faceswap' Model segment indices:
+        0: background, 1: skin, 2: ears, 3: hair, 4: glasses
         """
-        retval = [1, 2, 3, 4, 5, 10, 11, 12, 13]
+        retval = [1] if self._is_faceswap else [1, 2, 3, 4, 5, 10, 11, 12, 13]
+
         if self.config["include_glasses"]:
-            retval.append(6)
+            retval.append(4 if self._is_faceswap else 6)
         if self.config["include_ears"]:
-            retval.extend([7, 8, 9])
+            retval.extend([2] if self._is_faceswap else [7, 8, 9])
         if self.config["include_hair"]:
-            retval.append(17)
+            retval.append(3 if self._is_faceswap else 17)
         logger.debug("Selected segment indices: %s", retval)
         return retval
 
     def init_model(self):
         """ Initialize the BiSeNet Face Parsing model. """
+        lbls = 5 if self._is_faceswap else 19
         self.model = BiSeNet(self.model_path,
                              self.config["allow_growth"],
                              self._exclude_gpus,
                              self.input_size,
-                             19)
+                             lbls)
 
         placeholder = np.zeros((self.batchsize, self.input_size, self.input_size, 3),
                                dtype="float32")
@@ -75,8 +116,8 @@ def process_input(self, batch):
         """ Compile the detected faces for prediction """
-        mean = (0.485, 0.456, 0.406)
-        std = (0.229, 0.224, 0.225)
+        mean = (0.384, 0.314, 0.279) if self._is_faceswap else (0.485, 0.456, 0.406)
+        std = (0.324, 0.286, 0.275) if self._is_faceswap else (0.229, 0.224, 0.225)
 
         batch["feed"] = ((np.array([feed.face[..., :3] for feed in batch["feed_faces"]],
diff --git a/plugins/extract/mask/bisenet_fp_defaults.py b/plugins/extract/mask/bisenet_fp_defaults.py
index 313928473e..ab556299e5 100644
--- a/plugins/extract/mask/bisenet_fp_defaults.py
+++ b/plugins/extract/mask/bisenet_fp_defaults.py
@@ -63,6 +63,18 @@
         group="settings",
         gui_radio=False,
         fixed=True),
+    "weights": dict(
+        default="faceswap",
+        info="The trained weights to use.\n"
+             "\n\tfaceswap - Weights trained on wildly varied Faceswap extracted data to better "
+             "handle varying conditions, obstructions, glasses and multiple targets within a "
+             "single extracted image."
+             "\n\toriginal - The original weights trained on the CelebAMask-HQ dataset.",
+        choices=["faceswap", "original"],
+        datatype=str,
+        group="settings",
+        gui_radio=True,
+    ),
     "include_ears": dict(
         default=False,
         info="Whether to include ears within the face mask.",
@@ -77,8 +89,10 @@
     ),
     "include_glasses": dict(
         default=True,
-        info="Whether to include glasses within the face mask. NB: excluding glasses will mask "
-             "out the lenses as well as the frames.",
+        info="Whether to include glasses within the face mask.\n\tFor 'original' weights "
+             "excluding glasses will mask out the lenses as well as the frames.\n\tFor 'faceswap' "
+             "weights, the model has been trained to mask out lenses if eyes cannot be seen (i.e. "
+             "dark sunglasses) or just the frames if the eyes can be seen. ",
         datatype=bool,
         group="settings"
     ),
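
For reference, a minimal standalone sketch of the two behaviours the new "weights" option
switches between: the segment indices kept in the mask and the normalisation statistics
applied to the input batch. This is an illustration only, assuming plain NumPy and the usual
(x / 255 - mean) / std scaling used with these statistics; the helper names
select_segment_indices and normalise_batch do not exist in faceswap -- inside the plugin the
equivalent logic lives in Mask._get_segment_indices and Mask.process_input above.

import numpy as np


def select_segment_indices(is_faceswap, include_glasses=True, include_ears=False,
                           include_hair=False):
    """ Map the mask config options to the segment ids kept for each weights set.

    'original' labels: 0 background, 1 skin, 2/3 brows, 4/5 eyes, 6 glasses, 7/8 ears,
    9 earing, 10 nose, 11 mouth, 12/13 lips, 14/15 neck, 16 cloth, 17 hair, 18 hat
    'faceswap' labels: 0 background, 1 skin, 2 ears, 3 hair, 4 glasses
    """
    retval = [1] if is_faceswap else [1, 2, 3, 4, 5, 10, 11, 12, 13]
    if include_glasses:
        retval.append(4 if is_faceswap else 6)
    if include_ears:
        retval.extend([2] if is_faceswap else [7, 8, 9])
    if include_hair:
        retval.append(3 if is_faceswap else 17)
    return retval


def normalise_batch(faces, is_faceswap):
    """ Normalise a uint8 RGB batch with the statistics matching the selected weights. """
    mean = (0.384, 0.314, 0.279) if is_faceswap else (0.485, 0.456, 0.406)
    std = (0.324, 0.286, 0.275) if is_faceswap else (0.229, 0.224, 0.225)
    return (faces.astype("float32") / 255.0 - mean) / std


print(select_segment_indices(is_faceswap=True, include_hair=True))   # [1, 4, 3]
print(select_segment_indices(is_faceswap=False, include_hair=True))  # [1, 2, 3, 4, 5, 10, 11, 12, 13, 6, 17]
print(normalise_batch(np.zeros((1, 512, 512, 3), dtype="uint8"), is_faceswap=True).shape)

Because the 'faceswap' weights only expose five labels, init_model above builds the network
with 5 output classes instead of the original 19.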