diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index a3b168ab1e..a3b1c7d447 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,5 +1,7 @@ **Note: Please only report bugs in this repository. Just because you are getting an error message does not automatically mean you have discovered a bug. If you don't have a lot of experience with this type of project, or if you need for setup help and other issues in using the faceswap tool, please refer to the [faceswap-playground](https://github.com/deepfakes/faceswap-playground/issues) instead. The faceswap-playground is also an excellent place to ask questions and submit feedback.** +**Please always attach your generated crash_report.log to any bug report** + ## Expected behavior *Describe, in some detail, what you are trying to do and what the output is that you expect from the program.* diff --git a/.gitignore b/.gitignore index 9f1a273024..70913e7266 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,9 @@ !Dockerfile* !requirements* !.cache -!lib -!lib/face_alignment +!config/ +!lib/ +!lib/* !lib/gui !lib/gui/.cache/preview !lib/gui/.cache/icons @@ -21,9 +22,9 @@ !plugins/ !plugins/* !plugins/extract/* -!plugins/model/* +!plugins/train/* !tools !tools/lib* - +*.ini *.pyc __pycache__/ diff --git a/INSTALL.md b/INSTALL.md index d562914afb..bb9240bee5 100755 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,37 +1,37 @@ # Installing Faceswap - [Installing Faceswap](#installing-faceswap) - [Prerequisites](#prerequisites) - - [Hardware Requirements](#hardware-requirements) - - [Supported operating systems](#supported-operating-systems) + - [Hardware Requirements](#hardware-requirements) + - [Supported operating systems](#supported-operating-systems) - [Important before you proceed](#important-before-you-proceed) - [General Install Guide](#general-install-guide) - - [Installing dependencies](#installing-dependencies) - - [Getting the faceswap code](#getting-the-faceswap-code) - - [Setup](#setup) - - [About some of the options](#about-some-of-the-options) - - [Run the project](#run-the-project) - - [Notes](#notes) + - [Installing dependencies](#installing-dependencies) + - [Getting the faceswap code](#getting-the-faceswap-code) + - [Setup](#setup) + - [About some of the options](#about-some-of-the-options) + - [Run the project](#run-the-project) + - [Notes](#notes) - [Windows Install Guide](#windows-install-guide) - - [Prerequisites](#prerequisites-1) - - [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015) - - [Cuda](#cuda) - - [cuDNN](#cudnn) - - [CMake](#cmake) - - [Anaconda](#anaconda) - - [Git](#git) - - [Setup](#setup-1) - - [Anaconda](#anaconda-1) - - [Set up a virtual environment](#set-up-a-virtual-environment) - - [Entering your virtual environment](#entering-your-virtual-environment) - - [Faceswap](#faceswap) - - [Easy install](#easy-install) - - [Manual install](#manual-install) - - [Running Faceswap](#running-faceswap) - - [Create a desktop shortcut](#create-a-desktop-shortcut) - - [Updating faceswap](#updating-faceswap) - - [Dlib](#dlib) - - [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support) - - [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support) + - [Prerequisites](#prerequisites-1) + - [Microsoft Visual Studio 2015](#microsoft-visual-studio-2015) + - [Cuda](#cuda) + - [cuDNN](#cudnn) + - [CMake](#cmake) + - [Anaconda](#anaconda) + - [Git](#git) + - [Setup](#setup-1) + - [Anaconda](#anaconda-1) + - [Set up a virtual environment](#set-up-a-virtual-environment) + - [Entering your virtual environment](#entering-your-virtual-environment) + - [Faceswap](#faceswap) + - [Easy install](#easy-install) + - [Manual install](#manual-install) + - [Running Faceswap](#running-faceswap) + - [Create a desktop shortcut](#create-a-desktop-shortcut) + - [Updating faceswap](#updating-faceswap) + - [Dlib](#dlib) + - [Build Latest Dlib with GPU Support](#build-latest-dlib-with-gpu-support) + - [Easy install of Dlib without GPU Support](#easy-install-of-dlib-without-gpu-support) # Prerequisites Machine learning essentially involves a ton of trial and error. You're letting a program try millions of different settings to land on an algorithm that sort of does what you want it to do. This process is really really slow unless you have the hardware required to speed this up. diff --git a/USAGE.md b/USAGE.md index fefa64d34a..852e35479c 100755 --- a/USAGE.md +++ b/USAGE.md @@ -34,6 +34,8 @@ You can see the full list of arguments for extracting via help flag. i.e. python faceswap.py extract -h ``` +Some of the plugins have configurable options. You can find the config options in: `\plugins\extract\config.ini`. Extract needs to have been run at least once to generate the config file + ## TRAIN The training process will take the longest, especially on CPU. We specify the folders where the two faces are, and where we will save our training model. It will start hammering the training data once you run the command. I personally really like to go by the preview and quit the processing once I'm happy with the results. @@ -51,6 +53,9 @@ You can see the full list of arguments for training via help flag. i.e. python faceswap.py train -h ``` +Some of the plugins have configurable options. You can find the config options in: `\plugins\traom\config.ini`. Train needs to have been run at least once to generate the config file + + ## CONVERT Now that we're happy with our trained model, we can convert our video. How does it work? Similarly to the extraction script, actually! The conversion script basically detects a face in a picture using the same algorithm, quickly crops the image to the right size, runs our bot on this cropped image of the face it has found, and then (crudely) pastes the processed face back into the picture. @@ -86,7 +91,7 @@ python tools.py effmpeg -h ``` ## Extracting video frames with FFMPEG -Alternatively you can split a video into seperate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to seperate frames. +Alternatively you can split a video into separate frames using [ffmpeg](https://www.ffmpeg.org) for instance. Below is an example command to process a video to separate frames. ```bash ffmpeg -i /path/to/my/video.mp4 /path/to/output/video-frame-%d.png diff --git a/plugins/model/__init__.py b/config/.keep similarity index 100% rename from plugins/model/__init__.py rename to config/.keep diff --git a/lib/PixelShuffler.py b/lib/PixelShuffler.py deleted file mode 100644 index f4961fefa5..0000000000 --- a/lib/PixelShuffler.py +++ /dev/null @@ -1,88 +0,0 @@ -# PixelShuffler layer for Keras -# by t-ae -# https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 - -from keras.utils import conv_utils -from keras.engine.topology import Layer -import keras.backend as K - - -class PixelShuffler(Layer): - def __init__(self, size=(2, 2), data_format=None, **kwargs): - super(PixelShuffler, self).__init__(**kwargs) - self.data_format = K.normalize_data_format(data_format) - self.size = conv_utils.normalize_tuple(size, 2, 'size') - - def call(self, inputs): - - input_shape = K.int_shape(inputs) - if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + - str(4) + - '; Received input shape:', str(input_shape)) - - if self.data_format == 'channels_first': - batch_size, c, h, w = input_shape - if batch_size is None: - batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) - - out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) - out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) - out = K.reshape(out, (batch_size, oc, oh, ow)) - return out - - elif self.data_format == 'channels_last': - batch_size, h, w, c = input_shape - if batch_size is None: - batch_size = -1 - rh, rw = self.size - oh, ow = h * rh, w * rw - oc = c // (rh * rw) - - out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) - out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) - out = K.reshape(out, (batch_size, oh, ow, oc)) - return out - - def compute_output_shape(self, input_shape): - - if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + - str(4) + - '; Received input shape:', str(input_shape)) - - if self.data_format == 'channels_first': - height = input_shape[2] * self.size[0] if input_shape[2] is not None else None - width = input_shape[3] * self.size[1] if input_shape[3] is not None else None - channels = input_shape[1] // self.size[0] // self.size[1] - - if channels * self.size[0] * self.size[1] != input_shape[1]: - raise ValueError('channels of input and size are incompatible') - - return (input_shape[0], - channels, - height, - width) - - elif self.data_format == 'channels_last': - height = input_shape[1] * self.size[0] if input_shape[1] is not None else None - width = input_shape[2] * self.size[1] if input_shape[2] is not None else None - channels = input_shape[3] // self.size[0] // self.size[1] - - if channels * self.size[0] * self.size[1] != input_shape[3]: - raise ValueError('channels of input and size are incompatible') - - return (input_shape[0], - height, - width, - channels) - - def get_config(self): - config = {'size': self.size, - 'data_format': self.data_format} - base_config = super(PixelShuffler, self).get_config() - - return dict(list(base_config.items()) + list(config.items())) diff --git a/lib/aligner.py b/lib/aligner.py index 74e1b0d4bf..4770f908eb 100644 --- a/lib/aligner.py +++ b/lib/aligner.py @@ -12,29 +12,6 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name -MEAN_FACE_X = np.array([ - 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, - 0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, - 0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, - 0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132, - 0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127, - .551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532, - 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364, - 0.490127, 0.42689]) - -MEAN_FACE_Y = np.array([ - 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, - 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, - 0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, - 0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, - 0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578, - 0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192, - 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182, - 0.831803, 0.824182]) - -LANDMARKS_2D = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1) - - class Extract(): """ Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs """ @@ -42,8 +19,9 @@ class Extract(): def extract(self, image, face, size, align_eyes): """ Extract a face from an image """ logger.trace("size: %s. align_eyes: %s", size, align_eyes) + padding = int(size * 0.1875) alignment = get_align_mat(face, size, align_eyes) - extracted = self.transform(image, alignment, size, 48) + extracted = self.transform(image, alignment, size, padding) logger.trace("Returning face and alignment matrix: (alignment_matrix: %s)", alignment) return extracted, alignment @@ -60,8 +38,9 @@ def transform(self, image, mat, size, padding=0): """ Transform Image """ logger.trace("matrix: %s, size: %s. padding: %s", mat, size, padding) matrix = self.transform_matrix(mat, size, padding) + interpolators = get_matrix_scaling(matrix) return cv2.warpAffine( # pylint: disable=no-member - image, matrix, (size, size)) + image, matrix, (size, size), flags=interpolators[0]) def transform_points(self, points, mat, size, padding=0): """ Transform points along matrix """ @@ -144,12 +123,23 @@ def get_feature_mask(aligned_landmarks_68, size, return mask +def get_matrix_scaling(mat): + """ Get the correct interpolator """ + x_scale = np.sqrt(mat[0, 0] * mat[0, 0] + mat[0, 1] * mat[0, 1]) + y_scale = (mat[0, 0] * mat[1, 1] - mat[0, 1] * mat[1, 0]) / x_scale + avg_scale = (x_scale + y_scale) * 0.5 + if avg_scale >= 1.0: + interpolators = cv2.INTER_CUBIC, cv2.INTER_AREA # pylint: disable=no-member + else: + interpolators = cv2.INTER_AREA, cv2.INTER_CUBIC # pylint: disable=no-member + logger.trace("interpolator: %s, inverse interpolator: %s", interpolators[0], interpolators[1]) + return interpolators + + def get_align_mat(face, size, should_align_eyes): """ Return the alignment Matrix """ logger.trace("size: %s, should_align_eyes: %s", size, should_align_eyes) - mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), - LANDMARKS_2D, - True)[0:2] + mat_umeyama = umeyama(np.array(face.landmarks_as_xy[17:]), True)[0:2] if should_align_eyes is False: return mat_umeyama diff --git a/lib/alignments.py b/lib/alignments.py index ad25dfce81..2780631529 100644 --- a/lib/alignments.py +++ b/lib/alignments.py @@ -270,33 +270,6 @@ def yield_original_index_reverse(image_alignments, number_alignments): # << LEGACY FUNCTIONS >> # - # < Original Frame Dimensions > # - # For dfaker and convert-adjust the original dimensions of a frame are - # required to calculate the transposed landmarks. As transposed landmarks - # will change on face size, we store original frame dimensions - # These were not previously required, so this adds the dimensions - # to the landmarks file - - def get_legacy_no_dims(self): - """ Return a list of frames that do not contain the original frame - height and width attributes """ - logger.debug("Getting alignments without frame_dims") - keys = list() - for key, val in self.data.items(): - for alignment in val: - if "frame_dims" not in alignment.keys(): - keys.append(key) - break - logger.debug("Got alignments without frame_dims: %s", len(keys)) - return keys - - def add_dimensions(self, frame_name, dimensions): - """ Backward compatability fix. Add frame dimensions - to alignments """ - logger.trace("Adding dimensions: (frame: '%s', dimensions: %s)", frame_name, dimensions) - for face in self.get_faces_in_frame(frame_name): - face["frame_dims"] = dimensions - # < Rotation > # # The old rotation method would rotate the image to find a face, then # store the rotated landmarks along with a rotation value to tell the @@ -319,20 +292,20 @@ def get_legacy_rotation(self): logger.debug("Got alignments containing legacy rotations: %s", len(keys)) return keys - def rotate_existing_landmarks(self, frame_name): + def rotate_existing_landmarks(self, frame_name, frame): """ Backwards compatability fix. Rotates the landmarks to their correct position and deletes r - NB: The original frame dimensions must be passed in otherwise + NB: The original frame must be passed in otherwise the transformation cannot be performed """ logger.trace("Rotating existing landmarks for frame: '%s'", frame_name) + dims = frame.shape[:2] for face in self.get_faces_in_frame(frame_name): angle = face.get("r", 0) if not angle: logger.trace("Landmarks do not require rotation: '%s'", frame_name) return logger.trace("Rotating landmarks: (frame: '%s', angle: %s)", frame_name, angle) - dims = face["frame_dims"] r_mat = self.get_original_rotation_matrix(dims, angle) rotate_landmarks(face, r_mat) del face["r"] diff --git a/lib/cli.py b/lib/cli.py index 777c558510..4e775d48b4 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -103,9 +103,41 @@ def execute_script(self, arguments): safe_shutdown() -class FullPaths(argparse.Action): +class Slider(argparse.Action): # pylint: disable=too-few-public-methods + """ Adds support for the GUI slider + + An additional option 'min_max' must be provided containing tuple of min and max accepted + values. + + 'rounding' sets the decimal places for floats or the step interval for ints. + """ + def __init__(self, option_strings, dest, nargs=None, min_max=None, rounding=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super().__init__(option_strings, dest, **kwargs) + self.min_max = min_max + self.rounding = rounding + + def _get_kwargs(self): + names = ["option_strings", + "dest", + "nargs", + "const", + "default", + "type", + "choices", + "help", + "metavar", + "min_max", # Tuple containing min and max values of scale + "rounding"] # Decimal places to round floats to or step interval for ints + return [(name, getattr(self, name)) for name in names] + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class FullPaths(argparse.Action): # pylint: disable=too-few-public-methods """ Expand user- and relative-paths """ - # pylint: disable=too-few-public-methods def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, os.path.abspath( os.path.expanduser(values))) @@ -124,26 +156,23 @@ class FileFullPaths(FullPaths): see lib/gui/utils.py FileHandler for current GUI filetypes """ # pylint: disable=too-few-public-methods - def __init__(self, option_strings, dest, nargs=None, filetypes=None, - **kwargs): + def __init__(self, option_strings, dest, nargs=None, filetypes=None, **kwargs): super(FileFullPaths, self).__init__(option_strings, dest, **kwargs) if nargs is not None: raise ValueError("nargs not allowed") self.filetypes = filetypes def _get_kwargs(self): - names = [ - "option_strings", - "dest", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "filetypes" - ] + names = ["option_strings", + "dest", + "nargs", + "const", + "default", + "type", + "choices", + "help", + "metavar", + "filetypes"] return [(name, getattr(self, name)) for name in names] @@ -185,19 +214,17 @@ def __init__(self, option_strings, dest, nargs=None, filetypes=None, self.filetypes = filetypes def _get_kwargs(self): - names = [ - "option_strings", - "dest", - "nargs", - "const", - "default", - "type", - "choices", - "help", - "metavar", - "filetypes", - "action_option" - ] + names = ["option_strings", + "dest", + "nargs", + "const", + "default", + "type", + "choices", + "help", + "metavar", + "filetypes", + "action_option"] return [(name, getattr(self, name)) for name in names] @@ -282,6 +309,13 @@ def get_global_arguments(): "help": "Path to store the logfile. Leave blank to store in the " "faceswap folder", "default": None}) + # This is a hidden argument to indicate that the GUI is being used, + # so the preview window should be redirected Accordingly + global_args.append({"opts": ("-gui", "--gui"), + "action": "store_true", + "dest": "redirect_gui", + "default": False, + "help": argparse.SUPPRESS}) return global_args @staticmethod @@ -342,11 +376,14 @@ def get_argument_list(): "dest": "alignments_path", "help": "Optional path to an alignments file."}) argument_list.append({"opts": ("-l", "--ref_threshold"), + "action": Slider, + "min_max": (0.01, 0.99), + "rounding": 2, "type": float, "dest": "ref_threshold", "default": 0.6, - "help": "Threshold for positive face " - "recognition"}) + "help": "Threshold for positive face recognition. For use with " + "nfilter or filter. Lower values are stricter."}) argument_list.append({"opts": ("-n", "--nfilter"), "type": str, "dest": "nfilter", @@ -389,7 +426,7 @@ def get_optional_arguments(): "fallback."}) argument_list.append({ "opts": ("-D", "--detector"), - "type": str, + "type": str.lower, "choices": PluginLoader.get_available_extractors( "detect"), "default": "mtcnn", @@ -404,7 +441,7 @@ def get_optional_arguments(): "\n\talignment to dlib"}) argument_list.append({ "opts": ("-A", "--aligner"), - "type": str, + "type": str.lower, "choices": PluginLoader.get_available_extractors( "align"), "default": "fan", @@ -413,38 +450,6 @@ def get_optional_arguments(): "\n\tresource intensive, but less accurate." "\n'fan': Face Alignment Network. Best aligner." "\n\tGPU heavy."}) - argument_list.append({"opts": ("-mtms", "--mtcnn-minsize"), - "type": int, - "dest": "mtcnn_minsize", - "default": 20, - "help": "The minimum size of a face to be " - "accepted. Lower values use " - "significantly more VRAM. Minimum " - "value is 10. Default is 20 " - "(MTCNN detector only)"}) - argument_list.append({"opts": ("-mtth", "--mtcnn-threshold"), - "nargs": "+", - "type": str, - "dest": "mtcnn_threshold", - "default": ["0.6", "0.7", "0.7"], - "help": "R|Three step threshold for face " - "detection. Should be\nthree decimal " - "numbers each less than 1. Eg:\n" - "'--mtcnn-threshold 0.6 0.7 0.7'.\n" - "1st stage: obtains face candidates.\n" - "2nd stage: refinement of face " - "candidates.\n3rd stage: further " - "refinement of face candidates.\n" - "Default is 0.6 0.7 0.7 " - "(MTCNN detector only)"}) - argument_list.append({"opts": ("-mtsc", "--mtcnn-scalefactor"), - "type": float, - "dest": "mtcnn_scalefactor", - "default": 0.709, - "help": "The scale factor for the image " - "pyramid. Should be a decimal number " - "less than one. Default is 0.709 " - "(MTCNN detector only)"}) argument_list.append({"opts": ("-r", "--rotate-images"), "type": str, "dest": "rotate_images", @@ -458,13 +463,15 @@ def get_optional_arguments(): "exactly what angles to check"}) argument_list.append({"opts": ("-bt", "--blur-threshold"), "type": float, + "action": Slider, + "min_max": (0.0, 100.0), + "rounding": 1, "dest": "blur_thresh", - "default": None, - "help": "Automatically discard images blurrier " - "than the specified threshold. " - "Discarded images are moved into a " - "\"blurry\" sub-folder. Lower values " - "allow more blur"}) + "default": 0.0, + "help": "Automatically discard images blurrier than the specified " + "threshold. Discarded images are moved into a \"blurry\" " + "sub-folder. Lower values allow more blur. Set to 0.0 to " + "turn off."}) argument_list.append({"opts": ("-mp", "--multiprocess"), "action": "store_true", "default": False, @@ -476,12 +483,13 @@ def get_optional_arguments(): "otherwise this is automatic."}) argument_list.append({"opts": ("-sz", "--size"), "type": int, + "action": Slider, + "min_max": (128, 512), "default": 256, - "help": "The output size of extracted faces. " - "Make sure that the model you intend " - "to train supports your required " - "size. This will only need to be " - "changed for hi-res models."}) + "rounding": 64, + "help": "The output size of extracted faces. Make sure that the " + "model you intend to train supports your required size. " + "This will only need to be changed for hi-res models."}) argument_list.append({"opts": ("-s", "--skip-existing"), "action": "store_true", "dest": "skip_existing", @@ -512,13 +520,15 @@ def get_optional_arguments(): argument_list.append({"opts": ("-si", "--save-interval"), "dest": "save_interval", "type": int, - "default": None, - "help": "Automatically save the alignments file " - "after a set amount of frames. Will " - "only save at the end of extracting by " - "default. WARNING: Don't interrupt the " - "script when writing the file because " - "it might get corrupted."}) + "action": Slider, + "min_max": (0, 1000), + "rounding": 10, + "default": 0, + "help": "Automatically save the alignments file after a set amount " + "of frames. Will only save at the end of extracting by " + "default. WARNING: Don't interrupt the script when writing " + "the file because it might get corrupted. Set to 0 to turn " + "off"}) return argument_list @@ -552,57 +562,73 @@ def get_optional_arguments(): "specified, all faces will be " "converted"}) argument_list.append({"opts": ("-t", "--trainer"), - "type": str, - # case sensitive because this is used to - # load a plug-in. + "type": str.lower, "choices": PluginLoader.get_available_models(), "default": PluginLoader.get_default_model(), "help": "Select the trainer that was used to " "create the model"}) argument_list.append({"opts": ("-c", "--converter"), - "type": str, - # case sensitive because this is used - # to load a plugin. - "choices": ("Masked", "Adjust"), - "default": "Masked", + "type": str.lower, + "choices": PluginLoader.get_available_converters(), + "default": "masked", "help": "Converter to use"}) + argument_list.append({ + "opts": ("-M", "--mask-type"), + "type": str.lower, + "dest": "mask_type", + "choices": ["rect", + "ellipse", + "smoothed", + "facehull", + "facehull_rect", + "dfl", + "cnn"], + "default": "facehull_rect", + "help": "R|Mask to use to replace faces." + "\nrect: Rectangle around face." + "\nellipse: Oval around face." + "\nsmoothed: Rectangle around face with smoothing." + "\nfacehull: Face cutout based on landmarks." + "\nfacehull_rect: Rectangle around faces with facehull" + "\n\tbetween the edges of the face and the background." + "\ndfl: A Face Hull mask from DeepFaceLabs." + "\ncnn: Not yet implemented"}) argument_list.append({"opts": ("-b", "--blur-size"), + "type": float, + "action": Slider, + "min_max": (0.0, 100.0), + "rounding": 2, + "default": 5.0, + "help": "Blur kernel size as a percentage of the swap area. Smooths " + "the transition between the swapped face and the background " + "image."}) + argument_list.append({"opts": ("-e", "--erosion-size"), + "dest": "erosion_size", + "type": float, + "action": Slider, + "min_max": (-100.0, 100.0), + "rounding": 2, + "default": 0.0, + "help": "Erosion kernel size as a percentage of the mask radius " + "area. Positive values apply erosion which reduces the size " + "of the swapped area. Negative values apply dilation which " + "increases the swapped area"}) + argument_list.append({"opts": ("-g", "--gpus"), "type": int, - "default": 2, - "help": "Blur size. (Masked converter only)"}) - argument_list.append({"opts": ("-e", "--erosion-kernel-size"), - "dest": "erosion_kernel_size", - "type": int, - "default": None, - "help": "Erosion kernel size. Positive values " - "apply erosion which reduces the edge " - "of the swapped face. Negative values " - "apply dilation which allows the " - "swapped face to cover more space. " - "(Masked converter only)"}) - argument_list.append({"opts": ("-M", "--mask-type"), - # lowercase this, because it's just a - # string later on. - "type": str.lower, - "dest": "mask_type", - "choices": ["rect", - "facehull", - "facehullandrect"], - "default": "facehullandrect", - "help": "Mask to use to replace faces. " - "(Masked converter only)"}) + "action": Slider, + "min_max": (1, 10), + "rounding": 1, + "default": 1, + "help": "Number of GPUs to use for conversion"}) argument_list.append({"opts": ("-sh", "--sharpen"), "type": str.lower, "dest": "sharpen_image", - "choices": ["bsharpen", "gsharpen"], + "choices": ["box_filter", "gaussian_filter"], "default": None, - "help": "Use Sharpen Image. bsharpen for Box " - "Blur, gsharpen for Gaussian Blur " - "(Masked converter only)"}) - argument_list.append({"opts": ("-g", "--gpus"), - "type": int, - "default": 1, - "help": "Number of GPUs to use for conversion"}) + "help": "Sharpen the masked facial region of " + "the converted images. Choice of filter " + "to use in sharpening process -- box" + "filter or gaussian filter."}) argument_list.append({"opts": ("-fr", "--frame-ranges"), "nargs": "+", "type": str, @@ -628,25 +654,25 @@ def get_optional_arguments(): "action": "store_true", "dest": "seamless_clone", "default": False, - "help": "Use cv2's seamless clone. " - "(Masked converter only)"}) + "help": "Use cv2's seamless clone function to " + "remove extreme gradients at the mask " + "seam by smoothing colors."}) argument_list.append({"opts": ("-mh", "--match-histogram"), "action": "store_true", "dest": "match_histogram", "default": False, - "help": "Use histogram matching. " - "(Masked converter only)"}) - argument_list.append({"opts": ("-sm", "--smooth-mask"), - "action": "store_true", - "dest": "smooth_mask", - "default": False, - "help": "Smooth mask (Adjust converter only)"}) + "help": "Adjust the histogram of each color " + "channel in the swapped reconstruction " + "to equal the histogram of the masked " + "area in the orginal image"}) argument_list.append({"opts": ("-aca", "--avg-color-adjust"), "action": "store_true", "dest": "avg_color_adjust", "default": False, - "help": "Average color adjust. " - "(Adjust converter only)"}) + "help": "Adjust the mean of each color channel " + " in the swapped reconstruction to " + "equal the mean of the masked area in " + "the orginal image"}) argument_list.append({"opts": ("-dt", "--draw-transparent"), "action": "store_true", "dest": "draw_transparent", @@ -667,18 +693,38 @@ def get_argument_list(): argument_list = list() argument_list.append({"opts": ("-A", "--input-A"), "action": DirFullPaths, - "dest": "input_A", - "default": "input_A", + "dest": "input_a", + "default": "input_a", "help": "Input directory. A directory " "containing training images for face A. " "Defaults to 'input'"}) argument_list.append({"opts": ("-B", "--input-B"), "action": DirFullPaths, - "dest": "input_B", - "default": "input_B", + "dest": "input_b", + "default": "input_b", "help": "Input directory. A directory " "containing training images for face B. " "Defaults to 'input'"}) + argument_list.append({"opts": ("-ala", "--alignments-A"), + "action": FileFullPaths, + "filetypes": 'alignments', + "type": str, + "dest": "alignments_path_a", + "default": None, + "help": "Path to alignments file for training set A. Only required " + "if you are using a masked model or warp-to-landmarks is " + "enabled. Defaults to /alignments.json if not " + "provided."}) + argument_list.append({"opts": ("-alb", "--alignments-B"), + "action": FileFullPaths, + "filetypes": 'alignments', + "type": str, + "dest": "alignments_path_b", + "default": None, + "help": "Path to alignments file for training set B. Only required " + "if you are using a masked model or warp-to-landmarks is " + "enabled. Defaults to /alignments.json if not " + "provided."}) argument_list.append({"opts": ("-m", "--model-dir"), "action": DirFullPaths, "dest": "model_dir", @@ -686,32 +732,51 @@ def get_argument_list(): "help": "Model directory. This is where the " "training data will be stored. " "Defaults to 'model'"}) - argument_list.append({"opts": ("-s", "--save-interval"), - "type": int, - "dest": "save_interval", - "default": 100, - "help": "Sets the number of iterations before " - "saving the model"}) argument_list.append({"opts": ("-t", "--trainer"), - "type": str, + "type": str.lower, "choices": PluginLoader.get_available_models(), "default": PluginLoader.get_default_model(), "help": "Select which trainer to use, Use " "LowMem for cards with less than 2GB of " "VRAM"}) + argument_list.append({"opts": ("-s", "--save-interval"), + "type": int, + "action": Slider, + "min_max": (10, 1000), + "rounding": 10, + "dest": "save_interval", + "default": 100, + "help": "Sets the number of iterations before saving the model"}) argument_list.append({"opts": ("-bs", "--batch-size"), "type": int, + "action": Slider, + "min_max": (2, 256), + "rounding": 2, + "dest": "batch_size", "default": 64, - "help": "Batch size, as a power of 2 " - "(64, 128, 256, etc)"}) + "help": "Batch size, as a power of 2 (64, 128, 256, etc)"}) argument_list.append({"opts": ("-it", "--iterations"), "type": int, + "action": Slider, + "min_max": (0, 5000000), + "rounding": 20000, "default": 1000000, - "help": "Length of training in iterations"}) + "help": "Length of training in iterations."}) argument_list.append({"opts": ("-g", "--gpus"), "type": int, + "action": Slider, + "min_max": (1, 10), + "rounding": 1, "default": 1, "help": "Number of GPUs to use for training"}) + argument_list.append({"opts": ("-ps", "--preview-scale"), + "type": int, + "action": Slider, + "dest": "preview_scale", + "min_max": (25, 200), + "rounding": 25, + "default": 100, + "help": "Percentage amount to scale the preview by."}) argument_list.append({"opts": ("-p", "--preview"), "action": "store_true", "dest": "preview", @@ -724,20 +789,39 @@ def get_argument_list(): "default": False, "help": "Writes the training result to a file " "even on preview mode"}) - argument_list.append({"opts": ("-pl", "--use-perceptual-loss"), - "action": "store_true", - "dest": "perceptual_loss", - "default": False, - "help": "Use perceptual loss while training"}) argument_list.append({"opts": ("-ag", "--allow-growth"), "action": "store_true", "dest": "allow_growth", "default": False, "help": "Sets allow_growth option of Tensorflow " "to spare memory on some configs"}) + argument_list.append({"opts": ("-nl", "--no-logs"), + "action": "store_true", + "dest": "no_logs", + "default": False, + "help": "Disables TensorBoard logging. NB: Disabling logs means " + "that you will not be able to use the graph or analysis " + "for this session in the GUI."}) + argument_list.append({"opts": ("-wl", "--warp-to-landmarks"), + "action": "store_true", + "dest": "warp_to_landmarks", + "default": False, + "help": "Warps training faces to closely matched Landmarks from the " + "opposite face-set rather than randomly warping the face. " + "This is the 'dfaker' way of doing warping. Alignments " + "files for both sets of faces must be provided if using " + "this option."}) + argument_list.append({"opts": ("-nf", "--no-flip"), + "action": "store_true", + "dest": "no_flip", + "default": False, + "help": "To effectively learn, a random set of images are flipped " + "horizontally. Sometimes it is desirable for this not to " + "occur. Generally this should be left off except for " + "during 'fit training'."}) argument_list.append({"opts": ("-tia", "--timelapse-input-A"), "action": DirFullPaths, - "dest": "timelapse_input_A", + "dest": "timelapse_input_a", "default": None, "help": "For if you want a timelapse: " "The input folder for the timelapse. " @@ -748,7 +832,7 @@ def get_argument_list(): "--timelapse-input-B parameter."}) argument_list.append({"opts": ("-tib", "--timelapse-input-B"), "action": DirFullPaths, - "dest": "timelapse_input_B", + "dest": "timelapse_input_b", "default": None, "help": "For if you want a timelapse: " "The input folder for the timelapse. " @@ -765,13 +849,6 @@ def get_argument_list(): "If the input folders are supplied but " "no output folder, it will default to " "your model folder /timelapse/"}) - # This is a hidden argument to indicate that the GUI is being used, - # so the preview window should be redirected Accordingly - argument_list.append({"opts": ("-gui", "--gui"), - "action": "store_true", - "dest": "redirect_gui", - "default": False, - "help": argparse.SUPPRESS}) return argument_list diff --git a/lib/config.py b/lib/config.py new file mode 100644 index 0000000000..fa4d9af215 --- /dev/null +++ b/lib/config.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +""" Default configurations for faceswap + Extends out configparser funcionality + by checking for default config updates + and returning data in it's correct format """ + +import logging +import os +import sys +from collections import OrderedDict +from configparser import ConfigParser + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class FaceswapConfig(): + """ Config Items """ + def __init__(self, section): + """ Init Configuration """ + logger.debug("Initializing: %s", self.__class__.__name__) + self.configfile = self.get_config_file() + self.config = ConfigParser(allow_no_value=True) + self.defaults = OrderedDict() + self.config.optionxform = str + self.section = section + + self.set_defaults() + self.handle_config() + logger.debug("Initialized: %s", self.__class__.__name__) + + def set_defaults(self): + """ Override for plugin specific config defaults + + Should be a series of self.add_section() and self.add_item() calls + + e.g: + + section = "sect_1" + self.add_section(title=section, + info="Section 1 Information") + + self.add_item(section=section, + title="option_1", + datatype=bool, + default=False, + info="sect_1 option_1 information") + """ + raise NotImplementedError + + @property + def config_dict(self): + """ Collate global options and requested section into a dictionary + with the correct datatypes """ + conf = dict() + for sect in ("global", self.section): + if sect not in self.config.sections(): + continue + for key in self.config[sect]: + if key.startswith(("#", "\n")): # Skip comments + continue + conf[key] = self.get(sect, key) + return conf + + def get(self, section, option): + """ Return a config item in it's correct format """ + logger.debug("Getting config item: (section: '%s', option: '%s')", section, option) + datatype = self.defaults[section][option]["type"] + if datatype == bool: + func = self.config.getboolean + elif datatype == int: + func = self.config.getint + elif datatype == float: + func = self.config.getfloat + else: + func = self.config.get + retval = func(section, option) + if isinstance(retval, str) and retval.lower() == "none": + retval = None + logger.debug("Returning item: (type: %s, value: %s)", datatype, retval) + return retval + + def get_config_file(self): + """ Return the config file from the calling folder """ + dirname = os.path.dirname(sys.modules[self.__module__].__file__) + folder, fname = os.path.split(dirname) + retval = os.path.join(os.path.dirname(folder), "config", "{}.ini".format(fname)) + logger.debug("Config File location: '%s'", retval) + return retval + + def add_section(self, title=None, info=None): + """ Add a default section to config file """ + logger.debug("Add section: (title: '%s', info: '%s')", title, info) + if None in (title, info): + raise ValueError("Default config sections must have a title and " + "information text") + self.defaults[title] = OrderedDict() + self.defaults[title]["helptext"] = info + + def add_item(self, section=None, title=None, datatype=str, + default=None, info=None, rounding=None, min_max=None, choices=None): + """ Add a default item to a config section + + For int or float values, rounding and min_max must be set + This is for the slider in the GUI. The min/max values are not enforced: + rounding: sets the decimal places for floats or the step interval for ints. + min_max: tuple of min and max accepted values + + For str values choices can be set to validate input and create a combo box + in the GUI + + """ + logger.debug("Add item: (section: '%s', title: '%s', datatype: '%s', default: '%s', " + "info: '%s', rounding: '%s', min_max: %s, choices: %s)", + section, title, datatype, default, info, rounding, min_max, choices) + + choices = list() if not choices else choices + + if None in (section, title, default, info): + raise ValueError("Default config items must have a section, " + "title, defult and " + "information text") + if not self.defaults.get(section, None): + raise ValueError("Section does not exist: {}".format(section)) + if datatype not in (str, bool, float, int): + raise ValueError("'datatype' must be one of str, bool, float or " + "int: {} - {}".format(section, title)) + if datatype in (float, int) and (rounding is None or min_max is None): + raise ValueError("'rounding' and 'min_max' must be set for numerical options") + if not isinstance(choices, (list, tuple)): + raise ValueError("'choices' must be a list or tuple") + self.defaults[section][title] = {"default": default, + "helptext": info, + "type": datatype, + "rounding": rounding, + "min_max": min_max, + "choices": choices} + + def check_exists(self): + """ Check that a config file exists """ + if not os.path.isfile(self.configfile): + logger.debug("Config file does not exist: '%s'", self.configfile) + return False + logger.debug("Config file exists: '%s'", self.configfile) + return True + + def create_default(self): + """ Generate a default config if it does not exist """ + logger.debug("Creating default Config") + for section, items in self.defaults.items(): + logger.debug("Adding section: '%s')", section) + self.insert_config_section(section, items["helptext"]) + for item, opt in items.items(): + logger.debug("Adding option: (item: '%s', opt: '%s'", item, opt) + if item == "helptext": + continue + self.insert_config_item(section, + item, + opt["default"], + opt) + self.save_config() + + def insert_config_section(self, section, helptext, config=None): + """ Insert a section into the config """ + logger.debug("Inserting section: (section: '%s', helptext: '%s', config: '%s')", + section, helptext, config) + config = self.config if config is None else config + helptext = self.format_help(helptext, is_section=True) + config.add_section(section) + config.set(section, helptext) + logger.debug("Inserted section: '%s'", section) + + def insert_config_item(self, section, item, default, option, + config=None): + """ Insert an item into a config section """ + logger.debug("Inserting item: (section: '%s', item: '%s', default: '%s', helptext: '%s', " + "config: '%s')", section, item, default, option["helptext"], config) + config = self.config if config is None else config + helptext = option["helptext"] + helptext += self.set_helptext_choices(option) + helptext += "\n[Default: {}]".format(default) + helptext = self.format_help(helptext, is_section=False) + config.set(section, helptext) + config.set(section, item, str(default)) + logger.debug("Inserted item: '%s'", item) + + @staticmethod + def set_helptext_choices(option): + """ Set the helptext choices """ + choices = "" + if option["choices"]: + choices = "\nChoose from: {}".format(option["choices"]) + elif option["type"] == bool: + choices = "\nChoose from: True, False" + elif option["type"] == int: + cmin, cmax = option["min_max"] + choices = "\nSelect an integer between {} and {}".format(cmin, cmax) + elif option["type"] == float: + cmin, cmax = option["min_max"] + choices = "\nSelect a decimal number between {} and {}".format(cmin, cmax) + return choices + + @staticmethod + def format_help(helptext, is_section=False): + """ Format comments for default ini file """ + logger.debug("Formatting help: (helptext: '%s', is_section: '%s')", helptext, is_section) + helptext = '# {}'.format(helptext.replace("\n", "\n# ")) + if is_section: + helptext = helptext.upper() + else: + helptext = "\n{}".format(helptext) + logger.debug("formatted help: '%s'", helptext) + return helptext + + def load_config(self): + """ Load values from config """ + logger.info("Loading config: '%s'", self.configfile) + self.config.read(self.configfile) + + def save_config(self): + """ Save a config file """ + logger.info("Updating config at: '%s'", self.configfile) + f_cfgfile = open(self.configfile, "w") + self.config.write(f_cfgfile) + f_cfgfile.close() + + def validate_config(self): + """ Check for options in default config against saved config + and add/remove as appropriate """ + logger.debug("Validating config") + if self.check_config_change(): + self.add_new_config_items() + self.check_config_choices() + logger.debug("Validated config") + + def add_new_config_items(self): + """ Add new items to the config file """ + logger.debug("Updating config") + new_config = ConfigParser(allow_no_value=True) + for section, items in self.defaults.items(): + self.insert_config_section(section, items["helptext"], new_config) + for item, opt in items.items(): + if item == "helptext": + continue + if section not in self.config.sections(): + logger.debug("Adding new config section: '%s'", section) + opt_value = opt["default"] + else: + opt_value = self.config[section].get(item, opt["default"]) + self.insert_config_item(section, + item, + opt_value, + opt, + new_config) + self.config = new_config + self.config.optionxform = str + self.save_config() + logger.debug("Updated config") + + def check_config_choices(self): + """ Check that config items are valid choices """ + logger.debug("Checking config choices") + for section, items in self.defaults.items(): + for item, opt in items.items(): + if item == "helptext" or not opt["choices"]: + continue + opt_value = self.config.get(section, item) + if opt_value.lower() == "none" and any(choice.lower() == "none" + for choice in opt["choices"]): + continue + if opt_value not in opt["choices"]: + default = str(opt["default"]) + logger.warning("'%s' is not a valid config choice for '%s': '%s'. Defaulting " + "to: '%s'", opt_value, section, item, default) + self.config.set(section, item, default) + logger.debug("Checked config choices") + + def check_config_change(self): + """ Check whether new default items have been added or removed + from the config file compared to saved version """ + if set(self.config.sections()) != set(self.defaults.keys()): + logger.debug("Default config has new section(s)") + return True + + for section, items in self.defaults.items(): + opts = [opt for opt in items.keys() if opt != "helptext"] + exists = [opt for opt in self.config[section].keys() + if not opt.startswith(("# ", "\n# "))] + if set(exists) != set(opts): + logger.debug("Default config has new item(s)") + return True + logger.debug("Default config has not changed") + return False + + def handle_config(self): + """ Handle the config """ + logger.debug("Handling config") + if not self.check_exists(): + self.create_default() + self.load_config() + self.validate_config() + logger.debug("Handled config") diff --git a/lib/faces_detect.py b/lib/faces_detect.py index 4082efbc4a..c0fa2995c9 100644 --- a/lib/faces_detect.py +++ b/lib/faces_detect.py @@ -3,7 +3,7 @@ import logging from dlib import rectangle as d_rectangle # pylint: disable=no-name-in-module -from lib.aligner import Extract as AlignerExtract, get_align_mat +from lib.aligner import Extract as AlignerExtract, get_align_mat, get_matrix_scaling logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -12,14 +12,13 @@ class DetectedFace(): """ Detected face and landmark information """ def __init__( # pylint: disable=invalid-name self, image=None, x=None, w=None, y=None, h=None, - frame_dims=None, landmarksXY=None): + landmarksXY=None): logger.trace("Initializing %s", self.__class__.__name__) self.image = image self.x = x self.w = w self.y = y self.h = h - self.frame_dims = frame_dims self.landmarksXY = landmarksXY self.hash = None # Hash must be set when the file is saved due to image compression @@ -63,17 +62,12 @@ def image_to_face(self, image): self.x: self.x + self.w] def to_alignment(self): - """ Convert a detected face to alignment dict - - NB: frame_dims should be the height and width - of the original frame. """ - + """ Convert a detected face to alignment dict """ alignment = dict() alignment["x"] = self.x alignment["w"] = self.w alignment["y"] = self.y alignment["h"] = self.h - alignment["frame_dims"] = self.frame_dims alignment["landmarksXY"] = self.landmarksXY alignment["hash"] = self.hash logger.trace("Returning: %s", alignment) @@ -87,23 +81,22 @@ def from_alignment(self, alignment, image=None): self.w = alignment["w"] self.y = alignment["y"] self.h = alignment["h"] - self.frame_dims = alignment["frame_dims"] self.landmarksXY = alignment["landmarksXY"] # Manual tool does not know the final hash so default to None self.hash = alignment.get("hash", None) if image is not None and image.any(): self.image_to_face(image) logger.trace("Created from alignment: (x: %s, w: %s, y: %s. h: %s, " - "frame_dims: %s, landmarks: %s)", - self.x, self.w, self.y, self.h, self.frame_dims, self.landmarksXY) + "landmarks: %s)", + self.x, self.w, self.y, self.h, self.landmarksXY) # <<< Aligned Face methods and properties >>> # - def load_aligned(self, image, size=256, padding=48, align_eyes=False): + def load_aligned(self, image, size=256, align_eyes=False): """ No need to load aligned information for all uses of this class, so only call this to load the information for easy reference to aligned properties for this face """ - logger.trace("Loading aligned face: (size: %s, padding: %s, align_eyes: %s)", - size, padding, align_eyes) + logger.trace("Loading aligned face: (size: %s, align_eyes: %s)", size, align_eyes) + padding = int(size * 0.1875) self.aligned["size"] = size self.aligned["padding"] = padding self.aligned["align_eyes"] = align_eyes @@ -153,3 +146,8 @@ def adjusted_matrix(self): self.aligned["padding"]) logger.trace("Returning: %s", mat) return mat + + @property + def adjusted_interpolators(self): + """ Return the interpolator and reverse interpolator for the adjusted matrix """ + return get_matrix_scaling(self.adjusted_matrix) diff --git a/lib/gui/__init__.py b/lib/gui/__init__.py index e7da147079..dca41f2ae0 100644 --- a/lib/gui/__init__.py +++ b/lib/gui/__init__.py @@ -1,7 +1,9 @@ from lib.gui.command import CommandNotebook from lib.gui.display import DisplayNotebook -from lib.gui.options import CliOptions, Config -from lib.gui.stats import CurrentSession +from lib.gui.options import CliOptions +from lib.gui.menu import MainMenuBar +from lib.gui.popup_configure import popup_config +from lib.gui.stats import Session from lib.gui.statusbar import StatusBar -from lib.gui.utils import ConsoleOut, Images +from lib.gui.utils import ConsoleOut, get_config, get_images, initialize_config, initialize_images from lib.gui.wrapper import ProcessWrapper diff --git a/lib/gui/command.py b/lib/gui/command.py index a7744d25da..2b6e608be6 100644 --- a/lib/gui/command.py +++ b/lib/gui/command.py @@ -5,44 +5,42 @@ import tkinter as tk from tkinter import ttk -from .options import Config from .tooltip import Tooltip -from .utils import ContextMenu, Images, FileHandler +from .utils import ContextMenu, FileHandler, get_images, get_config, set_slider_rounding logger = logging.getLogger(__name__) # pylint: disable=invalid-name -class CommandNotebook(ttk.Notebook): +class CommandNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors """ Frame to hold each individual tab of the command notebook """ - def __init__(self, parent, cli_options, tk_vars, scaling_factor): - logger.debug("Initializing %s: (parent: %s, cli_options: %s, tk_vars: %s, " - "scaling_factor: %s", self.__class__.__name__, parent, cli_options, - tk_vars, scaling_factor) + def __init__(self, parent): + logger.debug("Initializing %s: (parent: %s)", self.__class__.__name__, parent) + scaling_factor = get_config().scaling_factor width = int(420 * scaling_factor) height = int(500 * scaling_factor) ttk.Notebook.__init__(self, parent, width=width, height=height) parent.add(self) - self.cli_opts = cli_options - self.tk_vars = tk_vars self.actionbtns = dict() - self.set_running_task_trace() self.build_tabs() + get_config().command_notebook = self logger.debug("Initialized %s", self.__class__.__name__) def set_running_task_trace(self): """ Set trigger action for the running task to change the action buttons text and command """ logger.debug("Set running trace") - self.tk_vars["runningtask"].trace("w", self.change_action_button) + tk_vars = get_config().tk_vars + tk_vars["runningtask"].trace("w", self.change_action_button) def build_tabs(self): """ Build the tabs for the relevant command """ logger.debug("Build Tabs") - for category in self.cli_opts.categories: - cmdlist = self.cli_opts.commands[category] + cli_opts = get_config().cli_opts + for category in cli_opts.categories: + cmdlist = cli_opts.commands[category] for command in cmdlist: title = command.title() commandtab = CommandTab(self, category, command) @@ -52,9 +50,11 @@ def build_tabs(self): def change_action_button(self, *args): """ Change the action button to relevant control """ logger.debug("Update Action Buttons: (args: %s", args) + tk_vars = get_config().tk_vars + for cmd in self.actionbtns.keys(): btnact = self.actionbtns[cmd] - if self.tk_vars["runningtask"].get(): + if tk_vars["runningtask"].get(): ttl = "Terminate" hlp = "Exit the running process" else: @@ -65,7 +65,7 @@ def change_action_button(self, *args): Tooltip(btnact, text=hlp, wraplength=200) -class CommandTab(ttk.Frame): +class CommandTab(ttk.Frame): # pylint: disable=too-many-ancestors """ Frame to hold each individual tab of the command notebook """ def __init__(self, parent, category, command): @@ -74,9 +74,7 @@ def __init__(self, parent, category, command): ttk.Frame.__init__(self, parent) self.category = category - self.cli_opts = parent.cli_opts self.actionbtns = parent.actionbtns - self.tk_vars = parent.tk_vars self.command = command self.build_tab() @@ -100,7 +98,7 @@ def add_frame_separator(self): logger.debug("Added frame seperator") -class OptionsFrame(ttk.Frame): +class OptionsFrame(ttk.Frame): # pylint: disable=too-many-ancestors """ Options Frame - Holds the Options for each command """ def __init__(self, parent): @@ -108,7 +106,6 @@ def __init__(self, parent): ttk.Frame.__init__(self, parent) self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) - self.opts = parent.cli_opts self.command = parent.command self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) @@ -121,7 +118,8 @@ def __init__(self, parent): self.chkbtns = self.checkbuttons_frame() self.build_frame() - self.opts.set_context_option(self.command) + cli_opts = get_config().cli_opts + cli_opts.set_context_option(self.command) logger.debug("Initialized %s", self.__class__.__name__) def checkbuttons_frame(self): @@ -150,7 +148,8 @@ def build_frame(self): self.add_scrollbar() self.canvas.bind("", self.resize_frame) - for option in self.opts.gen_command_options(self.command): + cli_opts = get_config().cli_opts + for option in cli_opts.gen_command_options(self.command): optioncontrol = OptionControl(self.command, option, self.optsframe, @@ -170,7 +169,7 @@ def add_scrollbar(self): self.optsframe.bind("", self.update_scrollbar) logger.debug("Added Options Scrollbar") - def update_scrollbar(self, event): + def update_scrollbar(self, event): # pylint: disable=unused-argument """ Update the options frame scrollbar """ self.canvas.configure(scrollregion=self.canvas.bbox("all")) @@ -207,6 +206,7 @@ def build_full_control(self): if ctl == ttk.Checkbutton: dflt = self.option.get("default", False) choices = self.option["choices"] if ctl == ttk.Combobox else None + min_max = self.option["min_max"] if ctl == ttk.Scale else None ctlframe = self.build_one_control_frame() @@ -217,6 +217,7 @@ def build_full_control(self): self.option["value"] = self.build_one_control(ctlframe, ctlvars, choices, + min_max, sysbrowser) logger.debug("Built option control") @@ -228,6 +229,7 @@ def format_help(self, ctltitle): ctlhelp = ctlhelp[2:].replace("\n\t", " ").replace("\n'", "\n\n'") else: ctlhelp = " ".join(ctlhelp.split()) + ctlhelp = ctlhelp.replace("%%", "%") ctlhelp = ". ".join(i.capitalize() for i in ctlhelp.split(". ")) ctlhelp = ctltitle + " - " + ctlhelp logger.debug("Formatted control help: (title: '%s', help: '%s'", ctltitle, ctlhelp) @@ -249,15 +251,14 @@ def build_one_control_label(frame, control_title): lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) logger.debug("Built control label: '%s'", control_title) - def build_one_control(self, frame, controlvars, choices, sysbrowser): + def build_one_control(self, frame, controlvars, choices, min_max, sysbrowser): """ Build and place the option controls """ - logger.debug("Build control: (controlvars: %s, choices: %s, sysbrowser: %s", - controlvars, choices, sysbrowser) + logger.debug("Build control: (controlvars: %s, choices: %s, min_max: %s, sysbrowser: %s", + controlvars, choices, min_max, sysbrowser) control, control_title, default, helptext = controlvars default = default if default is not None else "" - var = tk.BooleanVar( - frame) if control == ttk.Checkbutton else tk.StringVar(frame) + var = tk.BooleanVar(frame) if control == ttk.Checkbutton else tk.StringVar(frame) var.set(default) if sysbrowser: @@ -268,6 +269,12 @@ def build_one_control(self, frame, controlvars, choices, sysbrowser): control_title, var, helptext) + elif control == ttk.Scale: + self.slider_control(control, + frame, + var, + min_max, + helptext) else: self.control_to_optionsframe(control, frame, @@ -292,6 +299,29 @@ def checkbutton_to_checkframe(self, control, control_title, var, helptext): Tooltip(ctl, text=helptext, wraplength=200) logger.debug("Added control checkframe: '%s'", control_title) + def slider_control(self, control, frame, tk_var, min_max, helptext): + """ A slider control with corresponding Entry box """ + logger.debug("Add slider control to Options Frame: %s", control) + d_type = self.option.get("type", float) + rnd = self.option.get("rounding", 2) if d_type == float else self.option.get("rounding", 1) + + tbox = ttk.Entry(frame, width=8, textvariable=tk_var, justify=tk.RIGHT) + tbox.pack(padx=(0, 5), side=tk.RIGHT) + ctl = control( + frame, + variable=tk_var, + command=lambda val, var=tk_var, dt=d_type, rn=rnd, mm=min_max: + set_slider_rounding(val, var, dt, rn, mm)) + ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) + rc_menu = ContextMenu(ctl) + rc_menu.cm_bind() + ctl["from_"] = min_max[0] + ctl["to"] = min_max[1] + + Tooltip(ctl, text=helptext, wraplength=720) + Tooltip(tbox, text=helptext, wraplength=720) + logger.debug("Added slider control to Options Frame: %s", control) + @staticmethod def control_to_optionsframe(control, frame, var, choices, helptext): """ Standard non-check buttons sit in the main options frame """ @@ -303,8 +333,7 @@ def control_to_optionsframe(control, frame, var, choices, helptext): if control == ttk.Combobox: logger.debug("Adding combo choices: %s", choices) ctl["values"] = [choice for choice in choices] - - Tooltip(ctl, text=helptext, wraplength=720) + Tooltip(ctl, text=helptext, wraplength=920) logger.debug("Added control to Options Frame: %s", control) def add_browser_buttons(self, frame, sysbrowser, filepath): @@ -312,7 +341,7 @@ def add_browser_buttons(self, frame, sysbrowser, filepath): logger.debug("Adding browser buttons: (sysbrowser: '%s', filepath: '%s'", sysbrowser, filepath) for browser in sysbrowser: - img = Images().icons[browser] + img = get_images().icons[browser] action = getattr(self, "ask_" + browser) filetypes = self.option.get("filetypes", "default") fileopn = ttk.Button(frame, @@ -351,7 +380,7 @@ def ask_save(filepath, filetypes=None): filepath.set(filename) @staticmethod - def ask_nothing(filepath, filetypes=None): + def ask_nothing(filepath, filetypes=None): # pylint: disable=unused-argument """ Method that does nothing, used for disabling open/save pop up """ return @@ -370,7 +399,7 @@ def ask_context(self, filepath, filetypes): filepath.set(filename) -class ActionFrame(ttk.Frame): +class ActionFrame(ttk.Frame): # pylint: disable=too-many-ancestors """Action Frame - Displays action controls for the command tab """ def __init__(self, parent): @@ -382,16 +411,16 @@ def __init__(self, parent): self.title = self.command.title() self.add_action_button(parent.category, - parent.actionbtns, - parent.tk_vars) - self.add_util_buttons(parent.cli_opts, parent.tk_vars) + parent.actionbtns) + self.add_util_buttons() logger.debug("Initialized %s", self.__class__.__name__) - def add_action_button(self, category, actionbtns, tk_vars): + def add_action_button(self, category, actionbtns): """ Add the action buttons for page """ logger.debug("Add action buttons: '%s'", self.title) actframe = ttk.Frame(self) actframe.pack(fill=tk.X, side=tk.LEFT) + tk_vars = get_config().tk_vars var_value = "{},{}".format(category, self.command) @@ -415,17 +444,17 @@ def add_action_button(self, category, actionbtns, tk_vars): wraplength=200) logger.debug("Added action buttons: '%s'", self.title) - def add_util_buttons(self, cli_options, tk_vars): + def add_util_buttons(self): """ Add the section utility buttons """ logger.debug("Add util buttons") utlframe = ttk.Frame(self) utlframe.pack(side=tk.RIGHT) - config = Config(cli_options, tk_vars) + config = get_config() for utl in ("load", "save", "clear", "reset"): logger.debug("Adding button: '%s'", utl) - img = Images().icons[utl] - action_cls = config if utl in (("save", "load")) else cli_options + img = get_images().icons[utl] + action_cls = config if utl in (("save", "load")) else config.cli_opts action = getattr(action_cls, utl) btnutl = ttk.Button(utlframe, image=img, diff --git a/lib/gui/display.py b/lib/gui/display.py index 650428b3d9..bb70b63423 100644 --- a/lib/gui/display.py +++ b/lib/gui/display.py @@ -4,48 +4,56 @@ What is displayed in the Display Frame varies depending on what tasked is being run """ +import logging import tkinter as tk from tkinter import ttk from .display_analysis import Analysis from .display_command import GraphDisplay, PreviewExtract, PreviewTrain +from .utils import get_config +logger = logging.getLogger(__name__) # pylint: disable=invalid-name -class DisplayNotebook(ttk.Notebook): + +class DisplayNotebook(ttk.Notebook): # pylint: disable=too-many-ancestors """ The display tabs """ - def __init__(self, parent, session, tk_vars, scaling_factor): + def __init__(self, parent): + logger.debug("Initializing %s", self.__class__.__name__) ttk.Notebook.__init__(self, parent, width=780) parent.add(self) - + tk_vars = get_config().tk_vars self.wrapper_var = tk_vars["display"] self.runningtask = tk_vars["runningtask"] - self.session = session self.set_wrapper_var_trace() - self.add_static_tabs(scaling_factor) + self.add_static_tabs() self.static_tabs = [child for child in self.tabs()] + logger.debug("Initialized %s", self.__class__.__name__) def set_wrapper_var_trace(self): """ Set the trigger actions for the display vars when they have been triggered in the Process Wrapper """ + logger.debug("Setting wrapper var trace") self.wrapper_var.trace("w", self.update_displaybook) - def add_static_tabs(self, scaling_factor): + def add_static_tabs(self): """ Add tabs that are permanently available """ + logger.debug("Adding static tabs") for tab in ("job queue", "analysis"): if tab == "job queue": continue # Not yet implemented if tab == "analysis": helptext = {"stats": "Summary statistics for each training session"} - frame = Analysis(self, tab, helptext, scaling_factor) + frame = Analysis(self, tab, helptext) else: frame = self.add_frame() self.add(frame, text=tab.title()) def add_frame(self): """ Add a single frame for holding tab's contents """ + logger.debug("Adding frame") frame = ttk.Frame(self) frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5) return frame @@ -58,32 +66,53 @@ def command_display(self, command): def extract_tabs(self): """ Build the extract tabs """ + logger.debug("Build extract tabs") helptext = ("Updates preview from output every 5 " "seconds to limit disk contention") PreviewExtract(self, "preview", helptext, 5000) + logger.debug("Built extract tabs") def train_tabs(self): """ Build the train tabs """ + logger.debug("Build train tabs") for tab in ("graph", "preview"): if tab == "graph": helptext = "Graph showing Loss vs Iterations" GraphDisplay(self, "graph", helptext, 5000) elif tab == "preview": helptext = "Training preview. Updated on every save iteration" - PreviewTrain(self, "preview", helptext, 5000) + PreviewTrain(self, "preview", helptext, 1000) + logger.debug("Built train tabs") def convert_tabs(self): """ Build the convert tabs Currently identical to Extract, so just call that """ + logger.debug("Build convert tabs") self.extract_tabs() + logger.debug("Built convert tabs") def remove_tabs(self): """ Remove all command specific tabs """ for child in self.tabs(): - if child not in self.static_tabs: - self.forget(child) - - def update_displaybook(self, *args): + if child in self.static_tabs: + continue + logger.debug("removing child: %s", child) + child_name = child.split(".")[-1] + child_object = self.children[child_name] + self.destroy_tabs_children(child_object) + self.forget(child) + + @staticmethod + def destroy_tabs_children(tab): + """ Destroy all tabs children + Children must be destroyed as forget only hides display + """ + logger.debug("Destroying children for tab: %s", tab) + for child in tab.winfo_children(): + logger.debug("Destroying child: %s", child) + child.destroy() + + def update_displaybook(self, *args): # pylint: disable=unused-argument """ Set the display tabs based on executing task """ command = self.wrapper_var.get() self.remove_tabs() diff --git a/lib/gui/display_analysis.py b/lib/gui/display_analysis.py index 75aff1fd5d..8d02c458c9 100644 --- a/lib/gui/display_analysis.py +++ b/lib/gui/display_analysis.py @@ -2,104 +2,138 @@ """ Analysis tab of Display Frame of the Faceswap GUI """ import csv +import logging +import os import tkinter as tk from tkinter import ttk from .display_graph import SessionGraph from .display_page import DisplayPage -from .stats import Calculations, SavedSessions, SessionsSummary, SessionsTotals +from .stats import Calculations, Session from .tooltip import Tooltip -from .utils import Images, FileHandler +from .utils import FileHandler, get_config, get_images + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class Analysis(DisplayPage): # pylint: disable=too-many-ancestors """ Session analysis tab """ - def __init__(self, parent, tabname, helptext, scaling_factor): - DisplayPage.__init__(self, parent, tabname, helptext) + def __init__(self, parent, tabname, helptext): + logger.debug("Initializing: %s: (parent, %s, tabname: '%s', helptext: '%s')", + self.__class__.__name__, parent, tabname, helptext) + super().__init__(parent, tabname, helptext) self.summary = None + self.session = None self.add_options() - self.add_main_frame(scaling_factor) + self.add_main_frame() + logger.debug("Initialized: %s", self.__class__.__name__) def set_vars(self): """ Analysis specific vars """ selected_id = tk.StringVar() - filename = tk.StringVar() - return {"selected_id": selected_id, - "filename": filename} + return {"selected_id": selected_id} - def add_main_frame(self, scaling_factor): + def add_main_frame(self): """ Add the main frame to the subnotebook to hold stats and session data """ + logger.debug("Adding main frame") mainframe = self.subnotebook_add_page("stats") self.stats = StatsData(mainframe, - self.vars["filename"], self.vars["selected_id"], - self.helptext["stats"], - scaling_factor) + self.helptext["stats"]) + logger.debug("Added main frame") def add_options(self): """ Add the options bar """ + logger.debug("Adding options") self.reset_session_info() options = Options(self) options.add_options() + logger.debug("Added options") def reset_session_info(self): """ Reset the session info status to default """ - self.vars["filename"].set(None) + logger.debug("Resetting session info") self.set_info("No session data loaded") def load_session(self): """ Load previously saved sessions """ + logger.debug("Loading session") self.clear_session() - filename = FileHandler("open", "session").retfile - if not filename: + fullpath = FileHandler("filename", "state").retfile + if not fullpath: + return + logger.debug("state_file: '%s'", fullpath) + model_dir, state_file = os.path.split(fullpath) + logger.debug("model_dir: '%s'", model_dir) + model_name = self.get_model_name(model_dir, state_file) + if not model_name: return - filename = filename.name - loaded_data = SavedSessions(filename).sessions - msg = filename - if len(filename) > 70: - msg = "...{}".format(filename[-70:]) - self.set_session_summary(loaded_data, msg) - self.vars["filename"].set(filename) + self.session = Session(model_dir=model_dir, model_name=model_name) + self.session.initialize_session(is_training=False) + msg = os.path.split(state_file)[0] + if len(msg) > 70: + msg = "...{}".format(msg[-70:]) + self.set_session_summary(msg) + + @staticmethod + def get_model_name(model_dir, state_file): + """ Get the state file from the model directory """ + logger.debug("Getting model name") + model_name = state_file.replace("_state.json", "") + logger.debug("model_name: %s", model_name) + logs_dir = os.path.join(model_dir, "{}_logs".format(model_name)) + if not os.path.isdir(logs_dir): + logger.warning("No logs folder found in folder: '%s'", logs_dir) + return None + return model_name def reset_session(self): - """ Load previously saved sessions """ + """ Reset currently training sessions """ + logger.debug("Reset current training session") self.clear_session() - if self.session.stats["iterations"] == 0: + session = get_config().session + if not session.initialized: + logger.debug("Training not running") print("Training not running") return - loaded_data = self.session.historical.sessions msg = "Currently running training session" - self.set_session_summary(loaded_data, msg) - self.vars["filename"].set("Currently running training session") + self.session = session + self.set_session_summary(msg) - def set_session_summary(self, data, message): + def set_session_summary(self, message): """ Set the summary data and info message """ - self.summary = SessionsSummary(data).summary + logger.debug("Setting session summary. (message: '%s')", message) + self.summary = self.session.full_summary self.set_info("Session: {}".format(message)) - self.stats.loaded_data = data + self.stats.session = self.session self.stats.tree_insert_data(self.summary) def clear_session(self): """ Clear sessions stats """ + logger.debug("Clearing session") self.summary = None - self.stats.loaded_data = None + self.stats.session = None self.stats.tree_clear() self.reset_session_info() def save_session(self): """ Save sessions stats to csv """ + logger.debug("Saving session") if not self.summary: + logger.debug("No summary data loaded. Nothing to save") print("No summary data loaded. Nothing to save") return savefile = FileHandler("save", "csv").retfile if not savefile: + logger.debug("No save file. Returning") return write_dicts = [val for val in self.summary.values()] fieldnames = sorted(key for key in write_dicts[0].keys()) + logger.debug("Saving to: '%s'", savefile) with savefile as outfile: csvout = csv.DictWriter(outfile, fieldnames) csvout.writeheader() @@ -110,8 +144,10 @@ def save_session(self): class Options(): """ Options bar of Analysis tab """ def __init__(self, parent): + logger.debug("Initializing: %s", self.__class__.__name__) self.optsframe = parent.optsframe self.parent = parent + logger.debug("Initialized: %s", self.__class__.__name__) def add_options(self): """ Add the display tab options """ @@ -120,9 +156,10 @@ def add_options(self): def add_buttons(self): """ Add the option buttons """ for btntype in ("reset", "clear", "save", "load"): + logger.debug("Adding button: '%s'", btntype) cmd = getattr(self.parent, "{}_session".format(btntype)) btn = ttk.Button(self.optsframe, - image=Images().icons[btntype], + image=get_images().icons[btntype], command=cmd) btn.pack(padx=2, side=tk.RIGHT) hlp = self.set_help(btntype) @@ -131,6 +168,7 @@ def add_buttons(self): @staticmethod def set_help(btntype): """ Set the helptext for option buttons """ + logger.debug("Setting help") hlp = "" if btntype == "reset": hlp = "Load/Refresh stats for the currently training session" @@ -145,25 +183,15 @@ def set_help(btntype): class StatsData(ttk.Frame): # pylint: disable=too-many-ancestors """ Stats frame of analysis tab """ - def __init__(self, - parent, - filename, - selected_id, - helptext, - scaling_factor): - ttk.Frame.__init__(self, parent) - self.pack(side=tk.TOP, - padx=5, - pady=5, - expand=True, - fill=tk.X, - anchor=tk.N) - - self.filename = filename - self.loaded_data = None + def __init__(self, parent, selected_id, helptext): + logger.debug("Initializing: %s: (parent, %s, selected_id: %s, helptext: '%s')", + self.__class__.__name__, parent, selected_id, helptext) + super().__init__(parent) + self.pack(side=tk.TOP, padx=5, pady=5, expand=True, fill=tk.X, anchor=tk.N) + + self.session = None # set when loading or clearing from parent self.selected_id = selected_id self.popup_positions = list() - self.scaling_factor = scaling_factor self.add_label() self.tree = ttk.Treeview(self, height=1, selectmode=tk.BROWSE) @@ -171,14 +199,17 @@ def __init__(self, orient="vertical", command=self.tree.yview) self.columns = self.tree_configure(helptext) + logger.debug("Initialized: %s", self.__class__.__name__) def add_label(self): """ Add Treeview Title """ + logger.debug("Adding Treeview title") lbl = ttk.Label(self, text="Session Stats", anchor=tk.CENTER) lbl.pack(side=tk.TOP, expand=True, fill=tk.X, padx=5, pady=5) def tree_configure(self, helptext): """ Build a treeview widget to hold the sessions stats """ + logger.debug("Configuring Treeview") self.tree.configure(yscrollcommand=self.scrollbar.set) self.tree.tag_configure("total", background="black", @@ -191,6 +222,7 @@ def tree_configure(self, helptext): def tree_columns(self): """ Add the columns to the totals treeview """ + logger.debug("Adding Treeview columns") columns = (("session", 40, "#"), ("start", 130, None), ("end", 130, None), @@ -202,6 +234,7 @@ def tree_columns(self): for column in columns: text = column[2] if column[2] else column[0].title() + logger.debug("Adding heading: '%s'", text) self.tree.heading(column[0], text=text) self.tree.column(column[0], width=column[1], @@ -212,19 +245,21 @@ def tree_columns(self): return [column[0] for column in columns] - def tree_insert_data(self, sessions): + def tree_insert_data(self, sessions_summary): """ Insert the data into the totals treeview """ - self.tree.configure(height=len(sessions)) + logger.debug("Inserting treeview data") + self.tree.configure(height=len(sessions_summary)) - for item in sessions: + for item in sessions_summary: values = [item[column] for column in self.columns] - kwargs = {"values": values, "image": Images().icons["graph"]} + kwargs = {"values": values, "image": get_images().icons["graph"]} if values[0] == "Total": kwargs["tags"] = "total" self.tree.insert("", "end", **kwargs) def tree_clear(self): """ Clear the totals tree """ + logger.debug("Clearing treeview data") self.tree.delete(* self.tree.get_children()) self.tree.configure(height=1) @@ -235,17 +270,22 @@ def select_item(self, event): selection = self.tree.focus() values = self.tree.item(selection, "values") if values: + logger.debug("Selected values: %s", values) self.selected_id.set(values[0]) if region == "tree": self.data_popup() def data_popup(self): """ Pop up a window and control it's position """ - toplevel = SessionPopUp(self.loaded_data, self.selected_id.get()) + logger.debug("Popping up data window") + scaling_factor = get_config().scaling_factor + toplevel = SessionPopUp(self.session.modeldir, + self.session.modelname, + self.selected_id.get()) toplevel.title(self.data_popup_title()) position = self.data_popup_get_position() - height = int(720 * self.scaling_factor) - width = int(400 * self.scaling_factor) + height = int(720 * scaling_factor) + width = int(400 * scaling_factor) toplevel.geometry("{}x{}+{}+{}".format(str(height), str(width), str(position[0]), @@ -254,14 +294,17 @@ def data_popup(self): def data_popup_title(self): """ Set the data popup title """ + logger.debug("Setting poup title") selected_id = self.selected_id.get() title = "All Sessions" if selected_id != "Total": - title = "Session #{}".format(selected_id) - return "{} - {}".format(title, self.filename.get()) + title = "{} Model: Session #{}".format(self.session.modelname.title(), selected_id) + logger.debug("Title: '%s'", title) + return "{} - {}".format(title, self.session.modeldir) def data_popup_get_position(self): """ Get the position of the next window """ + logger.debug("getting poup position") init_pos = [120, 120] pos = init_pos while True: @@ -270,25 +313,33 @@ def data_popup_get_position(self): break pos = [item + 200 for item in pos] init_pos, pos = self.data_popup_check_boundaries(init_pos, pos) + logger.debug("Position: %s", pos) return pos def data_popup_check_boundaries(self, initial_position, position): """ Check that the popup remains within the screen boundaries """ + logger.debug("Checking poup boundaries: (initial_position: %s, position: %s)", + initial_position, position) boundary_x = self.winfo_screenwidth() - 120 boundary_y = self.winfo_screenheight() - 120 if position[0] >= boundary_x or position[1] >= boundary_y: initial_position = [initial_position[0] + 50, initial_position[1]] position = initial_position + logger.debug("Returning poup boundaries: (initial_position: %s, position: %s)", + initial_position, position) return initial_position, position class SessionPopUp(tk.Toplevel): - """ Pop up for detailed grap/stats for selected session """ - def __init__(self, data, session_id): - tk.Toplevel.__init__(self) + """ Pop up for detailed graph/stats for selected session """ + def __init__(self, model_dir, model_name, session_id): + logger.debug("Initializing: %s: (model_dir: %s, model_name: %s, session_id: %s)", + self.__class__.__name__, model_dir, model_name, session_id) + super().__init__() - self.is_totals = session_id == "Total" - self.data = self.set_session_data(data, session_id) + self.session_id = session_id + self.session = Session(model_dir=model_dir, model_name=model_name) + self.initialize_session() self.graph = None self.display_data = None @@ -296,25 +347,35 @@ def __init__(self, data, session_id): self.vars = dict() self.graph_initialised = False self.build() - - def set_session_data(self, sessions, session_id): - """ Set the correct list index based on the passed in session is """ - if self.is_totals: - data = SessionsTotals(sessions).stats - else: - data = sessions[int(session_id) - 1] - return data + logger.debug("Initialized: %s", self.__class__.__name__) + + @property + def is_totals(self): + """ Return True if these are totals else False """ + return bool(self.session_id == "Total") + + def initialize_session(self): + """ Initialize the session """ + logger.debug("Initializing session") + kwargs = dict(is_training=False) + if not self.is_totals: + kwargs["session_id"] = int(self.session_id) + logger.debug("Session kwargs: %s", kwargs) + self.session.initialize_session(**kwargs) def build(self): """ Build the popup window """ + logger.debug("Building popup") optsframe, graphframe = self.layout_frames() self.opts_build(optsframe) self.compile_display_data() self.graph_build(graphframe) + logger.debug("Built popup") def layout_frames(self): """ Top level container frames """ + logger.debug("Layout frames") leftframe = ttk.Frame(self) leftframe.pack(side=tk.LEFT, expand=False, fill=tk.BOTH, pady=5) @@ -323,20 +384,25 @@ def layout_frames(self): rightframe = ttk.Frame(self) rightframe.pack(side=tk.RIGHT, fill=tk.BOTH, pady=5, expand=True) + logger.debug("Laid out frames") return leftframe, rightframe def opts_build(self, frame): - """ Options in options to the optsframe """ + """ Build Options into the options frame """ + logger.debug("Building Options") self.opts_combobox(frame) self.opts_checkbuttons(frame) + self.opts_loss_keys(frame) self.opts_entry(frame) self.opts_buttons(frame) sep = ttk.Frame(frame, height=2, relief=tk.RIDGE) sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) + logger.debug("Built Options") def opts_combobox(self, frame): """ Add the options combo boxes """ + logger.debug("Building Combo boxes") choices = {"Display": ("Loss", "Rate"), "Scale": ("Linear", "Log")} @@ -362,9 +428,11 @@ def opts_combobox(self, frame): hlp = self.set_help(item) Tooltip(cmbframe, text=hlp, wraplength=200) + logger.debug("Built Combo boxes") def opts_checkbuttons(self, frame): """ Add the options check buttons """ + logger.debug("Building Check Buttons") for item in ("raw", "trend", "avg", "outliers"): if item == "avg": text = "Show Rolling Average" @@ -384,9 +452,35 @@ def opts_checkbuttons(self, frame): hlp = self.set_help(item) Tooltip(ctl, text=hlp, wraplength=200) + logger.debug("Built Check Buttons") + + def opts_loss_keys(self, frame): + """ Add loss key selections """ + logger.debug("Building Loss Key Check Buttons") + loss_keys = self.session.loss_keys + lk_vars = dict() + for loss_key in sorted(loss_keys): + text = loss_key.replace("_", " ").title() + helptext = "Display {}".format(text) + var = tk.BooleanVar() + var.set(True) + var.trace("w", self.optbtn_reset) + lk_vars[loss_key] = var + + if len(loss_keys) == 1: + # Don't display if there's only one item + break + + ctl = ttk.Checkbutton(frame, variable=var, text=text) + ctl.pack(side=tk.TOP, padx=5, pady=5, anchor=tk.W) + Tooltip(ctl, text=helptext, wraplength=200) + + self.vars["loss_keys"] = lk_vars + logger.debug("Built Loss Key Check Buttons") def opts_entry(self, frame): """ Add the options entry boxes """ + logger.debug("Building Entry Boxes") for item in ("avgiterations", ): if item == "avgiterations": text = "Iterations to Average:" @@ -405,27 +499,32 @@ def opts_entry(self, frame): Tooltip(entframe, text=hlp, wraplength=200) self.vars[item] = ctl + logger.debug("Built Entry Boxes") def opts_buttons(self, frame): """ Add the option buttons """ + logger.debug("Building Buttons") btnframe = ttk.Frame(frame) btnframe.pack(fill=tk.X, pady=5, padx=5, side=tk.BOTTOM) for btntype in ("reset", "save"): cmd = getattr(self, "optbtn_{}".format(btntype)) btn = ttk.Button(btnframe, - image=Images().icons[btntype], + image=get_images().icons[btntype], command=cmd) btn.pack(padx=2, side=tk.RIGHT) hlp = self.set_help(btntype) Tooltip(btn, text=hlp, wraplength=200) + logger.debug("Built Buttons") def optbtn_save(self): """ Action for save button press """ + logger.debug("Saving File") savefile = FileHandler("save", "csv").retfile if not savefile: + logger.debug("Save Cancelled") return - + logger.debug("Saving to: %s", savefile) save_data = self.display_data.stats fieldnames = sorted(key for key in save_data.keys()) @@ -434,16 +533,21 @@ def optbtn_save(self): csvout.writerow(fieldnames) csvout.writerows(zip(*[save_data[key] for key in fieldnames])) - def optbtn_reset(self, *args): + def optbtn_reset(self, *args): # pylint: disable=unused-argument """ Action for reset button press and checkbox changes""" + logger.debug("Refreshing Graph") if not self.graph_initialised: return - self.compile_display_data() + valid = self.compile_display_data() + if not valid: + logger.debug("Invalid data") + return self.graph.refresh(self.display_data, self.vars["display"].get(), self.vars["scale"].get()) + logger.debug("Refreshed Graph") - def graph_scale(self, *args): + def graph_scale(self, *args): # pylint: disable=unused-argument """ Action for changing graph scale """ if not self.graph_initialised: return @@ -477,25 +581,53 @@ def set_help(control): def compile_display_data(self): """ Compile the data to be displayed """ - self.display_data = Calculations(self.data, - self.vars["display"].get(), - self.selections_to_list(), - self.vars["avgiterations"].get(), - self.vars["outliers"].get(), - self.is_totals) + logger.debug("Compiling Display Data") + + loss_keys = [key for key, val in self.vars["loss_keys"].items() + if val.get()] + logger.debug("Selected loss_keys: %s", loss_keys) + + selections = self.selections_to_list() + + if not self.check_valid_selection(loss_keys, selections): + return False + self.display_data = Calculations(session=self.session, + display=self.vars["display"].get(), + loss_keys=loss_keys, + selections=selections, + avg_samples=self.vars["avgiterations"].get(), + flatten_outliers=self.vars["outliers"].get(), + is_totals=self.is_totals) + logger.debug("Compiled Display Data") + return True + + def check_valid_selection(self, loss_keys, selections): + """ Check that there will be data to display """ + display = self.vars["display"].get().lower() + logger.debug("Validating selection. (loss_keys: %s, selections: %s, display: %s)", + loss_keys, selections, display) + if not selections or (display == "loss" and not loss_keys): + msg = "No data to display. Not refreshing" + logger.debug(msg) + print(msg) + return False + return True def selections_to_list(self): """ Compile checkbox selections to list """ + logger.debug("Compiling selections to list") selections = list() for key, val in self.vars.items(): if (isinstance(val, tk.BooleanVar) and key != "outliers" and val.get()): selections.append(key) + logger.debug("Compiling selections to list: %s", selections) return selections def graph_build(self, frame): """ Build the graph in the top right paned window """ + logger.debug("Building Graph") self.graph = SessionGraph(frame, self.display_data, self.vars["display"].get(), @@ -503,3 +635,4 @@ def graph_build(self, frame): self.graph.pack(expand=True, fill=tk.BOTH) self.graph.build() self.graph_initialised = True + logger.debug("Built Graph") diff --git a/lib/gui/display_command.py b/lib/gui/display_command.py index 9a424f48c3..01af8537d3 100644 --- a/lib/gui/display_command.py +++ b/lib/gui/display_command.py @@ -1,6 +1,7 @@ #!/usr/bin python3 """ Command specific tabs of Display Frame of the Faceswap GUI """ import datetime +import logging import os import tkinter as tk @@ -11,19 +12,23 @@ from .display_page import DisplayOptionalPage from .tooltip import Tooltip from .stats import Calculations -from .utils import Images, FileHandler +from .utils import FileHandler, get_config, get_images +logger = logging.getLogger(__name__) # pylint: disable=invalid-name -class PreviewExtract(DisplayOptionalPage): + +class PreviewExtract(DisplayOptionalPage): # pylint: disable=too-many-ancestors """ Tab to display output preview images for extract and convert """ def display_item_set(self): """ Load the latest preview if available """ - Images().load_latest_preview() - self.display_item = Images().previewoutput + logger.trace("Loading latest preview") + get_images().load_latest_preview() + self.display_item = get_images().previewoutput def display_item_process(self): """ Display the preview """ + logger.trace("Displaying preview") if not self.subnotebook.children: self.add_child() else: @@ -31,15 +36,17 @@ def display_item_process(self): def add_child(self): """ Add the preview label child """ + logger.debug("Adding child") preview = self.subnotebook_add_page(self.tabname, widget=None) - lblpreview = ttk.Label(preview, image=Images().previewoutput[1]) + lblpreview = ttk.Label(preview, image=get_images().previewoutput[1]) lblpreview.pack(side=tk.TOP, anchor=tk.NW) Tooltip(lblpreview, text=self.helptext, wraplength=200) def update_child(self): """ Update the preview image on the label """ + logger.trace("Updating preview") for widget in self.subnotebook_get_widgets(): - widget.configure(image=Images().previewoutput[1]) + widget.configure(image=get_images().previewoutput[1]) def save_items(self): """ Open save dialogue and save preview """ @@ -52,41 +59,56 @@ def save_items(self): "{}_{}.{}".format(filename, now, "png")) - Images().previewoutput[0].save(filename) + get_images().previewoutput[0].save(filename) + logger.debug("Saved preview to %s", filename) print("Saved preview to {}".format(filename)) class PreviewTrain(DisplayOptionalPage): # pylint: disable=too-many-ancestors """ Training preview image(s) """ + def __init__(self, *args, **kwargs): + self.update_preview = get_config().tk_vars["updatepreview"] + super().__init__(*args, **kwargs) def display_item_set(self): """ Load the latest preview if available """ - Images().load_training_preview() - self.display_item = Images().previewtrain + logger.trace("Loading latest preview") + if not self.update_preview.get(): + logger.trace("Preview not updated") + return + get_images().load_training_preview() + self.display_item = get_images().previewtrain def display_item_process(self): """ Display the preview(s) resized as appropriate """ - sortednames = sorted([name for name in Images().previewtrain.keys()]) + logger.trace("Displaying preview") + sortednames = sorted(list(get_images().previewtrain.keys())) existing = self.subnotebook_get_titles_ids() + should_update = self.update_preview.get() for name in sortednames: if name not in existing.keys(): self.add_child(name) - else: + elif should_update: tab_id = existing[name] self.update_child(tab_id, name) + if should_update: + self.update_preview.set(False) + def add_child(self, name): """ Add the preview canvas child """ + logger.debug("Adding child") preview = PreviewTrainCanvas(self.subnotebook, name) preview = self.subnotebook_add_page(name, widget=preview) Tooltip(preview, text=self.helptext, wraplength=200) - self.vars["modified"].set(Images().previewtrain[name][2]) + self.vars["modified"].set(get_images().previewtrain[name][2]) def update_child(self, tab_id, name): """ Update the preview canvas """ - if self.vars["modified"].get() != Images().previewtrain[name][2]: - self.vars["modified"].set(Images().previewtrain[name][2]) + logger.debug("Updating preview") + if self.vars["modified"].get() != get_images().previewtrain[name][2]: + self.vars["modified"].set(get_images().previewtrain[name][2]) widget = self.subnotebook_page_from_id(tab_id) widget.reload() @@ -102,11 +124,12 @@ def save_items(self): class PreviewTrainCanvas(ttk.Frame): # pylint: disable=too-many-ancestors """ Canvas to hold a training preview image """ def __init__(self, parent, previewname): + logger.debug("Initializing %s: (previewname: '%s')", self.__class__.__name__, previewname) ttk.Frame.__init__(self, parent) self.name = previewname - Images().resize_image(self.name, None) - self.previewimage = Images().previewtrain[self.name][1] + get_images().resize_image(self.name, None) + self.previewimage = get_images().previewtrain[self.name][1] self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True) @@ -115,18 +138,21 @@ def __init__(self, parent, previewname): image=self.previewimage, anchor=tk.NW) self.bind("", self.resize) + logger.debug("Initialized %s:", self.__class__.__name__) def resize(self, event): """ Resize the image to fit the frame, maintaining aspect ratio """ + logger.trace("Resizing preview image") framesize = (event.width, event.height) # Sometimes image is resized before frame is drawn framesize = None if framesize == (1, 1) else framesize - Images().resize_image(self.name, framesize) + get_images().resize_image(self.name, framesize) self.reload() def reload(self): """ Reload the preview image """ - self.previewimage = Images().previewtrain[self.name][1] + logger.trace("Reloading preview image") + self.previewimage = get_images().previewtrain[self.name][1] self.canvas.itemconfig(self.imgcanvas, image=self.previewimage) def save_preview(self, location): @@ -137,40 +163,63 @@ def save_preview(self, location): "{}_{}.{}".format(filename, now, "png")) - Images().previewtrain[self.name][0].save(filename) + get_images().previewtrain[self.name][0].save(filename) + logger.debug("Saved preview to %s", filename) print("Saved preview to {}".format(filename)) class GraphDisplay(DisplayOptionalPage): # pylint: disable=too-many-ancestors """ The Graph Tab of the Display section """ + def add_options(self): + """ Add the additional options """ + self.add_option_refresh() + super().add_options() + + def add_option_refresh(self): + """ Add refresh button to refresh graph immediately """ + logger.debug("Adding refresh option") + tk_var = get_config().tk_vars["refreshgraph"] + btnrefresh = ttk.Button(self.optsframe, + image=get_images().icons["reset"], + command=lambda: tk_var.set(True)) + btnrefresh.pack(padx=2, side=tk.RIGHT) + Tooltip(btnrefresh, + text="Graph updates every 100 iterations. Click to refresh now.", + wraplength=200) + def display_item_set(self): """ Load the graph(s) if available """ - if self.session.stats["iterations"] == 0: + session = get_config().session + if session.initialized and session.logging_disabled: + logger.trace("Logs disabled. Hiding graph") + self.set_info("Graph is disabled as 'no-logs' has been selected") self.display_item = None + elif session.initialized: + logger.trace("Loading graph") + self.display_item = session else: - self.display_item = self.session.stats + self.display_item = None def display_item_process(self): """ Add a single graph to the graph window """ - losskeys = self.display_item["losskeys"] - loss = self.display_item["loss"] - tabcount = int(len(losskeys) / 2) - existing = self.subnotebook_get_titles_ids() - for i in range(tabcount): - selectedkeys = losskeys[i * 2:(i + 1) * 2] - name = " - ".join(selectedkeys).title().replace("_", " ") - if name not in existing.keys(): - selectedloss = loss[i * 2:(i + 1) * 2] - selection = {"loss": selectedloss, - "losskeys": selectedkeys} - data = Calculations(session=selection, - display="loss", - selections=["raw", "trend"]) - self.add_child(name, data) + logger.trace("Adding graph") + existing = list(self.subnotebook_get_titles_ids().keys()) + + for loss_key in self.display_item.loss_keys: + tabname = loss_key.replace("_", " ").title() + if tabname in existing: + continue + + data = Calculations(session=get_config().session, + display="loss", + loss_keys=[loss_key], + selections=["raw", "trend"]) + self.add_child(tabname, data) def add_child(self, name, data): """ Add the graph for the selected keys """ + logger.debug("Adding child: %s", name) graph = TrainingGraph(self.subnotebook, data, "Loss") graph.build() graph = self.subnotebook_add_page(name, widget=graph) diff --git a/lib/gui/display_graph.py b/lib/gui/display_graph.py index 83665be96d..3abfe8a380 100755 --- a/lib/gui/display_graph.py +++ b/lib/gui/display_graph.py @@ -1,6 +1,7 @@ #!/usr/bin python3 """ Graph functions for Display Frame of the Faceswap GUI """ import datetime +import logging import os import tkinter as tk @@ -8,14 +9,17 @@ from math import ceil, floor import matplotlib +# pylint: disable=wrong-import-position matplotlib.use("TkAgg") -import matplotlib.animation as animation -from matplotlib import pyplot as plt, style + +from matplotlib import pyplot as plt, style # noqa from matplotlib.backends.backend_tkagg import (FigureCanvasTkAgg, - NavigationToolbar2Tk) + NavigationToolbar2Tk) # noqa + +from .tooltip import Tooltip # noqa +from .utils import get_config, get_images # noqa -from .tooltip import Tooltip -from .utils import Images +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ancestors @@ -26,28 +30,24 @@ class NavigationToolbar(NavigationToolbar2Tk): # pylint: disable=too-many-ances t[0] in ("Home", "Pan", "Zoom", "Save")] @staticmethod - def _Button(frame, text, file, command, extension=".gif"): + def _Button(frame, text, file, command, extension=".gif"): # pylint: disable=arguments-differ """ Map Buttons to their own frame. - Use custom button icons, - Use ttk buttons - pack to the right """ + Use custom button icons, Use ttk buttons pack to the right """ iconmapping = {"home": "reset", "filesave": "save", "zoom_to_rect": "zoom"} icon = iconmapping[file] if iconmapping.get(file, None) else file - img = Images().icons[icon] + img = get_images().icons[icon] btn = ttk.Button(frame, text=text, image=img, command=command) btn.pack(side=tk.RIGHT, padx=2) return btn def _init_toolbar(self): - """ Same as original but ttk widgets and standard - tooltips used. Separator added and message label - packed to the left """ + """ Same as original but ttk widgets and standard tooltips used. Separator added and + message label packed to the left """ xmin, xmax = self.canvas.figure.bbox.intervalx height, width = 50, xmax-xmin - ttk.Frame.__init__(self, master=self.window, - width=int(width), height=int(height)) + ttk.Frame.__init__(self, master=self.window, width=int(width), height=int(height)) sep = ttk.Frame(self, height=2, relief=tk.RIDGE) sep.pack(fill=tk.X, pady=(5, 0), side=tk.TOP) @@ -76,14 +76,14 @@ def _init_toolbar(self): class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors """ Base class for matplotlib line graphs """ def __init__(self, parent, data, ylabel): - ttk.Frame.__init__(self, parent) + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__(parent) style.use("ggplot") self.calcs = data self.ylabel = ylabel - self.colourmaps = ["Reds", "Blues", "Greens", - "Purples", "Oranges", "Greys", - "copper", "summer", "bone"] + self.colourmaps = ["Reds", "Blues", "Greens", "Purples", "Oranges", + "Greys", "copper", "summer", "bone"] self.lines = list() self.toolbar = None self.fig = plt.figure(figsize=(4, 4), dpi=75) @@ -92,26 +92,29 @@ def __init__(self, parent, data, ylabel): self.initiate_graph() self.update_plot(initiate=True) + logger.debug("Initialized %s", self.__class__.__name__) def initiate_graph(self): """ Place the graph canvas """ - self.plotcanvas.get_tk_widget().pack(side=tk.TOP, - padx=5, - fill=tk.BOTH, - expand=True) + logger.debug("Setting plotcanvas") + self.plotcanvas.get_tk_widget().pack(side=tk.TOP, padx=5, fill=tk.BOTH, expand=True) plt.subplots_adjust(left=0.100, bottom=0.100, right=0.95, top=0.95, wspace=0.2, hspace=0.2) + logger.debug("Set plotcanvas") def update_plot(self, initiate=True): """ Update the plot with incoming data """ + logger.trace("Updating plot") if initiate: + logger.debug("Initializing plot") self.lines = list() self.ax1.clear() self.axes_labels_set() + logger.debug("Initialized plot") fulldata = [item for item in self.calcs.stats.values()] self.axes_limits_set(fulldata) @@ -120,37 +123,37 @@ def update_plot(self, initiate=True): keys = list(self.calcs.stats.keys()) for idx, item in enumerate(self.lines_sort(keys)): if initiate: - self.lines.extend(self.ax1.plot(xrng, - self.calcs.stats[item[0]], - label=item[1], - linewidth=item[2], - color=item[3])) + self.lines.extend(self.ax1.plot(xrng, self.calcs.stats[item[0]], + label=item[1], linewidth=item[2], color=item[3])) else: self.lines[idx].set_data(xrng, self.calcs.stats[item[0]]) if initiate: self.legend_place() + logger.trace("Updated plot") def axes_labels_set(self): """ Set the axes label and range """ + logger.debug("Setting axes labels. y-label: '%s'", self.ylabel) self.ax1.set_xlabel("Iterations") self.ax1.set_ylabel(self.ylabel) def axes_limits_set_default(self): """ Set default axes limits """ + logger.debug("Setting default axes ranges") self.ax1.set_ylim(0.00, 100.0) self.ax1.set_xlim(0, 1) def axes_limits_set(self, data): """ Set the axes limits """ xmax = self.calcs.iterations - 1 if self.calcs.iterations > 1 else 1 - if data: ymin, ymax = self.axes_data_get_min_max(data) self.ax1.set_ylim(ymin, ymax) self.ax1.set_xlim(0, xmax) else: self.axes_limits_set_default() + logger.trace("axes ranges: (y: (%s, %s), x:(0, %s)", ymin, ymax, xmax) @staticmethod def axes_data_get_min_max(data): @@ -164,15 +167,18 @@ def axes_data_get_min_max(data): ymax.append(max(dataset) * 1000) ymin = floor(min(ymin)) / 1000 ymax = ceil(max(ymax)) / 1000 + logger.trace("ymin: %s, ymax: %s", ymin, ymax) return ymin, ymax def axes_set_yscale(self, scale): """ Set the Y-Scale to log or linear """ + logger.debug("yscale: '%s'", scale) self.ax1.set_yscale(scale) def lines_sort(self, keys): """ Sort the data keys into consistent order - and set line colourmap and line width """ + and set line color map and line width """ + logger.trace("Sorting lines") raw_lines = list() sorted_lines = list() for key in sorted(keys): @@ -184,29 +190,28 @@ def lines_sort(self, keys): groupsize = self.lines_groupsize(raw_lines, sorted_lines) sorted_lines = raw_lines + sorted_lines - lines = self.lines_style(sorted_lines, groupsize) return lines @staticmethod def lines_groupsize(raw_lines, sorted_lines): """ Get the number of items in each group. - If raw data isn't selected, then check - the length of remaining groups until - something is found """ + If raw data isn't selected, then check the length of + remaining groups until something is found """ groupsize = 1 if raw_lines: groupsize = len(raw_lines) else: for check in ("avg", "trend"): if any(item[0].startswith(check) for item in sorted_lines): - groupsize = len([item for item in sorted_lines - if item[0].startswith(check)]) + groupsize = len([item for item in sorted_lines if item[0].startswith(check)]) break + logger.trace(groupsize) return groupsize def lines_style(self, lines, groupsize): - """ Set the colourmap and linewidth for each group """ + """ Set the color map and line width for each group """ + logger.trace("Setting lines style") groups = int(len(lines) / groupsize) colours = self.lines_create_colors(groupsize, groups) for idx, item in enumerate(lines): @@ -215,21 +220,24 @@ def lines_style(self, lines, groupsize): return lines def lines_create_colors(self, groupsize, groups): - """ Create the colours """ + """ Create the colors """ colours = list() for i in range(1, groups + 1): for colour in self.colourmaps[0:groupsize]: cmap = matplotlib.cm.get_cmap(colour) cpoint = 1 - (i / 5) colours.append(cmap(cpoint)) + logger.trace(colours) return colours def legend_place(self): """ Place and format legend """ + logger.debug("Placing legend") self.ax1.legend(loc="upper right", ncol=2) def toolbar_place(self, parent): """ Add Graph Navigation toolbar """ + logger.debug("Placing toolbar") self.toolbar = NavigationToolbar(self.plotcanvas, parent) self.toolbar.pack(side=tk.BOTTOM) self.toolbar.update() @@ -240,72 +248,48 @@ class TrainingGraph(GraphBase): # pylint: disable=too-many-ancestors def __init__(self, parent, data, ylabel): GraphBase.__init__(self, parent, data, ylabel) + self.add_callback() - self.anim = None + def add_callback(self): + """ Add the variable trace to update graph on recent button or save iteration """ + get_config().tk_vars["refreshgraph"].trace("w", self.refresh) def build(self): - """ Update the plot area with loss values and cycle through to - animate """ - self.anim = animation.FuncAnimation(self.fig, - self.animate, - interval=200, - blit=False) + """ Update the plot area with loss values """ + logger.debug("Building training graph") self.plotcanvas.draw() + logger.debug("Built training graph") - def animate(self, i): + def refresh(self, *args): # pylint: disable=unused-argument """ Read loss data and apply to graph """ + logger.debug("Updating plot") self.calcs.refresh() self.update_plot(initiate=False) - - def set_animation_rate(self, iterations): - """ Change the animation update interval based on how - many iterations have been - There's no point calculating a graph over thousands of - points of data when the change will be miniscule """ - if iterations > 30000: - speed = 60000 # 1 min updates - elif iterations > 20000: - speed = 30000 # 30 sec updates - elif iterations > 10000: - speed = 10000 # 10 sec updates - elif iterations > 5000: - speed = 5000 # 5 sec updates - elif iterations > 1000: - speed = 2000 # 2 sec updates - elif iterations > 500: - speed = 1000 # 1 sec updates - elif iterations > 100: - speed = 500 # 0.5 sec updates - else: - speed = 200 # 200ms updates - if not self.anim.event_source.interval == speed: - self.anim.event_source.interval = speed + self.plotcanvas.draw() + get_config().tk_vars["refreshgraph"].set(False) def save_fig(self, location): """ Save the figure to file """ - keys = sorted([key.replace("raw_", "") - for key in self.calcs.stats.keys() + logger.debug("Saving graph: '%s'", location) + keys = sorted([key.replace("raw_", "") for key in self.calcs.stats.keys() if key.startswith("raw_")]) filename = " - ".join(keys) now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - filename = os.path.join(location, - "{}_{}.{}".format(filename, - now, - "png")) + filename = os.path.join(location, "{}_{}.{}".format(filename, now, "png")) self.fig.set_size_inches(16, 9) self.fig.savefig(filename, bbox_inches="tight", dpi=120) print("Saved graph to {}".format(filename)) + logger.debug("Saved graph: '%s'", filename) self.resize_fig() def resize_fig(self): """ Resize the figure back to the canvas """ - class Event(): - """ Event class that needs to be passed to - plotcanvas.resize """ + class Event(): # pylint: disable=too-few-public-methods + """ Event class that needs to be passed to plotcanvas.resize """ pass Event.width = self.winfo_width() Event.height = self.winfo_height() - self.plotcanvas.resize(Event) + self.plotcanvas.resize(Event) # pylint: disable=no-value-for-parameter class SessionGraph(GraphBase): # pylint: disable=too-many-ancestors @@ -316,18 +300,24 @@ def __init__(self, parent, data, ylabel, scale): def build(self): """ Build the session graph """ + logger.debug("Building session graph") self.toolbar_place(self) self.plotcanvas.draw() + logger.debug("Built session graph") def refresh(self, data, ylabel, scale): """ Refresh graph data """ + logger.debug("Refreshing session graph: (ylabel: '%s', scale: '%s')", ylabel, scale) self.calcs = data self.ylabel = ylabel self.set_yscale_type(scale) + logger.debug("Refreshed session graph") def set_yscale_type(self, scale): """ switch the y-scale and redraw """ + logger.debug("Updating scale type: '%s'", scale) self.scale = scale self.update_plot(initiate=True) self.axes_set_yscale(self.scale) self.plotcanvas.draw() + logger.debug("Updated scale type") diff --git a/lib/gui/display_page.py b/lib/gui/display_page.py index 8f840e20c7..a12b95115a 100644 --- a/lib/gui/display_page.py +++ b/lib/gui/display_page.py @@ -1,21 +1,25 @@ #!/usr/bin python3 """ Display Page parent classes for display section of the Faceswap GUI """ +import logging import tkinter as tk from tkinter import ttk from .tooltip import Tooltip -from .utils import Images +from .utils import get_images + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class DisplayPage(ttk.Frame): """ Parent frame holder for each tab. Defines uniform structure for each tab to inherit from """ def __init__(self, parent, tabname, helptext): + logger.debug("Initializing %s: (tabname: '%s', helptext: %s", + self.__class__.__name__, tabname, helptext) ttk.Frame.__init__(self, parent) self.pack(fill=tk.BOTH, side=tk.TOP, anchor=tk.NW) - self.session = parent.session self.runningtask = parent.runningtask self.helptext = helptext self.tabname = tabname @@ -30,32 +34,37 @@ def __init__(self, parent, tabname, helptext): self.add_frame_separator() self.set_mainframe_single_tab_style() parent.add(self, text=self.tabname.title()) + logger.debug("Initialized %s", self.__class__.__name__,) def add_optional_vars(self, varsdict): """ Add page specific variables """ if isinstance(varsdict, dict): for key, val in varsdict.items(): + logger.debug("Adding: (%s: %s)", key, val) self.vars[key] = val @staticmethod def set_vars(): - """ Overide to return a dict of page specific variables """ + """ Override to return a dict of page specific variables """ return dict() def add_subnotebook(self): """ Add the main frame notebook """ + logger.debug("Adding subnotebook") notebook = ttk.Notebook(self) notebook.pack(side=tk.TOP, anchor=tk.NW, fill=tk.BOTH, expand=True) return notebook def add_options_frame(self): """ Add the display tab options """ + logger.debug("Adding options frame") optsframe = ttk.Frame(self) optsframe.pack(side=tk.BOTTOM, padx=5, pady=5, fill=tk.X) return optsframe def add_options_info(self): """ Add the info bar """ + logger.debug("Adding options info") lblinfo = ttk.Label(self.optsframe, textvariable=self.vars["info"], anchor=tk.W, @@ -64,22 +73,26 @@ def add_options_info(self): def set_info(self, msg): """ Set the info message """ + logger.debug("Setting info: %s", msg) self.vars["info"].set(msg) def add_frame_separator(self): """ Add a separator between top and bottom frames """ + logger.debug("Adding frame seperator") sep = ttk.Frame(self, height=2, relief=tk.RIDGE) sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) @staticmethod def set_mainframe_single_tab_style(): """ Configure ttk notebook style to represent a single frame """ + logger.debug("Setting main frame single tab style") nbstyle = ttk.Style() nbstyle.configure("single.TNotebook", borderwidth=0) nbstyle.layout("single.TNotebook.Tab", []) def subnotebook_add_page(self, tabtitle, widget=None): """ Add a page to the sub notebook """ + logger.debug("Adding subnotebook page: %s", tabtitle) frame = widget if widget else ttk.Frame(self.subnotebook) frame.pack(padx=5, pady=5, fill=tk.BOTH, expand=True) self.subnotebook.add(frame, text=tabtitle) @@ -89,28 +102,32 @@ def subnotebook_add_page(self, tabtitle, widget=None): def subnotebook_configure(self): """ Configure notebook to display or hide tabs """ if len(self.subnotebook.children) == 1: + logger.debug("Setting single page style") self.subnotebook.configure(style="single.TNotebook") else: + logger.debug("Setting multi page style") self.subnotebook.configure(style="TNotebook") def subnotebook_hide(self): """ Hide the subnotebook. Used for hiding Optional displays """ - if self.subnotebook.winfo_ismapped(): + if self.subnotebook and self.subnotebook.winfo_ismapped(): + logger.debug("Hiding subnotebook") self.subnotebook.pack_forget() + self.subnotebook.destroy() + self.subnotebook = None def subnotebook_show(self): """ Show subnotebook. Used for displaying Optional displays """ - if not self.subnotebook.winfo_ismapped(): - self.subnotebook.pack(side=tk.TOP, - anchor=tk.NW, - fill=tk.BOTH, - expand=True) + if not self.subnotebook: + logger.debug("Showing subnotebook") + self.subnotebook = self.add_subnotebook() def subnotebook_get_widgets(self): """ Return each widget that sits within each subnotebook frame """ + logger.debug("Getting subnotebook widgets") for child in self.subnotebook.winfo_children(): for widget in child.winfo_children(): yield widget @@ -120,11 +137,13 @@ def subnotebook_get_titles_ids(self): tabs = dict() for tab_id in range(0, self.subnotebook.index("end")): tabs[self.subnotebook.tab(tab_id, "text")] = tab_id + logger.debug(tabs) return tabs def subnotebook_page_from_id(self, tab_id): """ Return subnotebook tab widget from it's ID """ tab_name = self.subnotebook.tabs()[tab_id].split(".")[-1] + logger.debug(tab_name) return self.subnotebook.children[tab_name] @@ -155,19 +174,23 @@ def set_vars(): modified = tk.DoubleVar() modified.set(None) - return {"enabled": enabled, - "ready": ready, - "modified": modified} + tk_vars = {"enabled": enabled, + "ready": ready, + "modified": modified} + logger.debug(tk_vars) + return tk_vars # INFO LABEL def set_info_text(self): """ Set waiting for display text """ if not self.vars["enabled"].get(): - self.set_info("{} disabled".format(self.tabname.title())) + msg = "{} disabled".format(self.tabname.title()) elif self.vars["enabled"].get() and not self.vars["ready"].get(): - self.set_info("Waiting for {}...".format(self.tabname)) + msg = "Waiting for {}...".format(self.tabname) else: - self.set_info("Displaying {}".format(self.tabname)) + msg = "Displaying {}".format(self.tabname) + logger.debug(msg) + self.set_info(msg) # DISPLAY OPTIONS BAR def add_options(self): @@ -177,8 +200,9 @@ def add_options(self): def add_option_save(self): """ Add save button to save page output to file """ + logger.debug("Adding save option") btnsave = ttk.Button(self.optsframe, - image=Images().icons["save"], + image=get_images().icons["save"], command=self.save_items) btnsave.pack(padx=2, side=tk.RIGHT) Tooltip(btnsave, @@ -187,6 +211,7 @@ def add_option_save(self): def add_option_enable(self): """ Add checkbutton to enable/disable page """ + logger.debug("Adding enable option") chkenable = ttk.Checkbutton(self.optsframe, variable=self.vars["enabled"], text="Enable {}".format(self.tabname), @@ -202,6 +227,7 @@ def save_items(self): def on_chkenable_change(self): """ Update the display immediately on a checkbutton change """ + logger.debug("Enabled checkbox changed") if self.vars["enabled"].get(): self.subnotebook_show() else: @@ -213,6 +239,7 @@ def update_page(self, waittime): if not self.runningtask.get(): return if self.vars["enabled"].get(): + logger.trace("Updating page") self.display_item_set() self.load_display() self.after(waittime, lambda t=waittime: self.update_page(t)) @@ -225,6 +252,7 @@ def load_display(self): """ Load the display """ if not self.display_item: return + logger.debug("Loading display") self.display_item_process() self.vars["ready"].set(True) self.set_info_text() diff --git a/lib/gui/menu.py b/lib/gui/menu.py new file mode 100644 index 0000000000..d135161b44 --- /dev/null +++ b/lib/gui/menu.py @@ -0,0 +1,134 @@ +#!/usr/bin python3 +""" The Menu Bars for faceswap GUI """ + +import logging +import os +import sys +import tkinter as tk + +from importlib import import_module + +from lib.Serializer import JSONSerializer + +from .utils import get_config +from .popup_configure import popup_config + + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class MainMenuBar(tk.Menu): + """ GUI Main Menu Bar """ + def __init__(self, master=None): + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__(master) + self.root = master + self.config = get_config() + + self.file_menu = tk.Menu(self, tearoff=0) + self.recent_menu = tk.Menu(self.file_menu, tearoff=0, postcommand=self.refresh_recent_menu) + self.edit_menu = tk.Menu(self, tearoff=0) + + self.build_file_menu() + self.build_edit_menu() + logger.debug("Initialized %s", self.__class__.__name__) + + def build_file_menu(self): + """ Add the file menu to the menu bar """ + logger.debug("Building File menu") + self.file_menu.add_command( + label="Load full config...", underline=0, command=self.config.load) + self.file_menu.add_command( + label="Save full config...", underline=0, command=self.config.save) + self.file_menu.add_separator() + self.file_menu.add_cascade(label="Open recent", underline=6, menu=self.recent_menu) + self.file_menu.add_separator() + self.file_menu.add_command( + label="Reset all to default", underline=0, command=self.config.cli_opts.reset) + self.file_menu.add_command( + label="Clear all", underline=0, command=self.config.cli_opts.clear) + self.file_menu.add_separator() + self.file_menu.add_command(label="Quit", underline=0, command=self.root.close_app) + self.add_cascade(label="File", menu=self.file_menu, underline=0) + logger.debug("Built File menu") + + def build_recent_menu(self): + """ Load recent files into menu bar """ + logger.debug("Building Recent Files menu") + serializer = JSONSerializer + menu_file = os.path.join(self.config.pathcache, ".recent.json") + if not os.path.isfile(menu_file): + self.clear_recent_files(serializer, menu_file) + with open(menu_file, "rb") as inp: + recent_files = serializer.unmarshal(inp.read().decode("utf-8")) + logger.debug("Loaded recent files: %s", recent_files) + for recent_item in recent_files: + filename, command = recent_item + logger.debug("processing: ('%s', %s)", filename, command) + if not os.path.isfile(filename): + logger.debug("File does not exist") + continue + lbl_command = command if command else "All" + self.recent_menu.add_command( + label="{} ({})".format(filename, lbl_command.title()), + command=lambda fnm=filename, cmd=command: self.config.load(cmd, fnm)) + self.recent_menu.add_separator() + self.recent_menu.add_command( + label="Clear recent files", + underline=0, + command=lambda srl=serializer, mnu=menu_file: self.clear_recent_files(srl, mnu)) + + logger.debug("Built Recent Files menu") + + @staticmethod + def clear_recent_files(serializer, menu_file): + """ Creates or clears recent file list """ + logger.debug("clearing recent files list: '%s'", menu_file) + recent_files = serializer.marshal(list()) + with open(menu_file, "wb") as out: + out.write(recent_files.encode("utf-8")) + + def refresh_recent_menu(self): + """ Refresh recent menu on save/load of files """ + self.recent_menu.delete(0, "end") + self.build_recent_menu() + + def build_edit_menu(self): + """ Add the edit menu to the menu bar """ + logger.debug("Building Edit menu") + edit_menu = tk.Menu(self, tearoff=0) + + configs = self.scan_for_configs() + for name in sorted(list(configs.keys())): + label = "Configure {} Plugins...".format(name.title()) + config = configs[name] + edit_menu.add_command( + label=label, + underline=10, + command=lambda conf=(name, config), root=self.root: popup_config(conf, root)) + self.add_cascade(label="Edit", menu=edit_menu, underline=0) + logger.debug("Built Edit menu") + + def scan_for_configs(self): + """ Scan for config.ini file locations """ + root_path = os.path.abspath(os.path.dirname(sys.argv[0])) + plugins_path = os.path.join(root_path, "plugins") + logger.debug("Scanning path: '%s'", plugins_path) + configs = dict() + for dirpath, _, filenames in os.walk(plugins_path): + if "_config.py" in filenames: + plugin_type = os.path.split(dirpath)[-1] + config = self.load_config(plugin_type) + configs[plugin_type] = config + logger.debug("Configs loaded: %s", sorted(list(configs.keys()))) + return configs + + @staticmethod + def load_config(plugin_type): + """ Load the config to generate config file if it doesn't exist and get filename """ + # Load config to generate default if doesn't exist + mod = ".".join(("plugins", plugin_type, "_config")) + module = import_module(mod) + config = module.Config(None) + logger.debug("Found '%s' config at '%s'", plugin_type, config.configfile) + return config diff --git a/lib/gui/options.py b/lib/gui/options.py index d39c12f308..32eb40e309 100644 --- a/lib/gui/options.py +++ b/lib/gui/options.py @@ -1,14 +1,13 @@ #!/usr/bin python3 -""" Cli Options and Config functions for the GUI """ +""" Cli Options for the GUI """ import inspect from argparse import SUPPRESS import logging from tkinter import ttk from lib import cli -from lib.Serializer import JSONSerializer import tools.cli as ToolsCli -from .utils import FileHandler, Images +from .utils import get_images logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -93,8 +92,7 @@ def process_options(self, command_options): logger.trace("Skipping suppressed option: %s", opt) continue ctl, sysbrowser, filetypes, action_option = self.set_control(opt) - opt["control_title"] = self.set_control_title( - opt.get("opts", "")) + opt["control_title"] = self.set_control_title(opt.get("opts", "")) opt["control"] = ctl opt["filesystem_browser"] = sysbrowser opt["filetypes"] = filetypes @@ -126,6 +124,8 @@ def set_control(self, option): sysbrowser, filetypes = self.set_sysbrowser(action, filetypes, action_option) + elif option.get("min_max", None): + ctl = ttk.Scale elif option.get("choices", "") != "": ctl = ttk.Combobox elif option.get("action", "") == "store_true": @@ -226,7 +226,7 @@ def gen_cli_arguments(self, command): optval = str(option.get("value", "").get()) opt = option["opts"][0] if command in ("extract", "convert") and opt == "-o": - Images().pathoutput = optval + get_images().pathoutput = optval if optval in ("False", ""): continue elif optval == "True": @@ -238,59 +238,3 @@ def gen_cli_arguments(self, command): else: opt = (opt, optval) yield opt - - -class Config(): - """ Actions for loading and saving Faceswap GUI command configurations """ - - def __init__(self, cli_opts, tk_vars): - logger.debug("Initializing %s", self.__class__.__name__) - self.cli_opts = cli_opts - self.serializer = JSONSerializer - self.tk_vars = tk_vars - logger.debug("Initialized %s", self.__class__.__name__) - - def load(self, command=None): - """ Load a saved config file """ - logger.debug("Loading config: (command: '%s')", command) - cfgfile = FileHandler("open", "config").retfile - if not cfgfile: - return - cfg = self.serializer.unmarshal(cfgfile.read()) - opts = self.get_command_options(cfg, command) if command else cfg - for cmd, opts in opts.items(): - self.set_command_args(cmd, opts) - logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile) - - def get_command_options(self, cfg, command): - """ return the saved options for the requested - command, if not loading global options """ - opts = cfg.get(command, None) - if not opts: - self.tk_vars["consoleclear"].set(True) - print("No {} section found in file".format(command)) - logger.info("No %s section found in file", command) - retval = {command: opts} - logger.debug(retval) - return retval - - def set_command_args(self, command, options): - """ Pass the saved config items back to the CliOptions """ - if not options: - return - for srcopt, srcval in options.items(): - optvar = self.cli_opts.get_one_option_variable(command, srcopt) - if not optvar: - continue - optvar.set(srcval) - - def save(self, command=None): - """ Save the current GUI state to a config file in json format """ - logger.debug("Saving config: (command: '%s')", command) - cfgfile = FileHandler("save", "config").retfile - if not cfgfile: - return - cfg = self.cli_opts.get_option_values(command) - cfgfile.write(self.serializer.marshal(cfg)) - cfgfile.close() - logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile) diff --git a/lib/gui/popup_configure.py b/lib/gui/popup_configure.py new file mode 100644 index 0000000000..c2331f85ef --- /dev/null +++ b/lib/gui/popup_configure.py @@ -0,0 +1,348 @@ +#!/usr/bin python3 +""" Configure Plugins popup of the Faceswap GUI """ + +from configparser import ConfigParser +import logging +import tkinter as tk + +from tkinter import ttk + +from .tooltip import Tooltip +from .utils import get_config, ContextMenu, set_slider_rounding + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name +POPUP = dict() + + +def popup_config(config, root): + """ Close any open popup and open requested popup """ + if POPUP: + p_key = list(POPUP.keys())[0] + logger.debug("Closing open popup: '%s'", p_key) + POPUP[p_key].destroy() + del POPUP[p_key] + window = ConfigurePlugins(config, root) + POPUP[config[0]] = window + + +class ConfigurePlugins(tk.Toplevel): + """ Pop up for detailed graph/stats for selected session """ + def __init__(self, config, root): + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__() + name, self.config = config + self.title("{} Plugins".format(name.title())) + self.set_geometry(root) + + self.page_frame = ttk.Frame(self) + self.page_frame.pack(fill=tk.BOTH, expand=True) + + self.plugin_info = dict() + self.config_dict_gui = self.get_config() + self.build() + self.update() + logger.debug("Initialized %s", self.__class__.__name__) + + def set_geometry(self, root): + """ Set pop-up geometry """ + scaling_factor = get_config().scaling_factor + pos_x = root.winfo_x() + 80 + pos_y = root.winfo_y() + 80 + width = int(720 * scaling_factor) + height = int(400 * scaling_factor) + logger.debug("Pop up Geometry: %sx%s, %s+%s", width, height, pos_x, pos_y) + self.geometry("{}x{}+{}+{}".format(width, height, pos_x, pos_y)) + + def get_config(self): + """ Format config into useful format for GUI and pull default value if a value has not + been supplied """ + logger.debug("Formatting Config for GUI") + conf = dict() + for section in self.config.config.sections(): + self.config.section = section + category = section.split(".")[0] + options = self.config.defaults[section] + conf.setdefault(category, dict())[section] = options + for key in options.keys(): + if key == "helptext": + self.plugin_info[section] = options[key] + continue + options[key]["value"] = self.config.config_dict.get(key, options[key]["default"]) + logger.debug("Formatted Config for GUI: %s", conf) + return conf + + def build(self): + """ Build the config popup """ + logger.debug("Building plugin config popup") + container = ttk.Notebook(self.page_frame) + container.pack(fill=tk.BOTH, expand=True) + categories = sorted(list(key for key in self.config_dict_gui.keys())) + if "global" in categories: # Move global to first item + categories.insert(0, categories.pop(categories.index("global"))) + for category in categories: + page = self.build_page(container, category) + container.add(page, text=category.title()) + + self.add_frame_separator() + self.add_actions() + logger.debug("Built plugin config popup") + + def build_page(self, container, category): + """ Build a plugin config page """ + logger.debug("Building plugin config page: '%s'", category) + plugins = sorted(list(key for key in self.config_dict_gui[category].keys())) + if any(plugin != category for plugin in plugins): + page = ttk.Notebook(container) + page.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + for plugin in plugins: + frame = ConfigFrame(page, + self.config_dict_gui[category][plugin], + self.plugin_info[plugin]) + title = plugin[plugin.rfind(".") + 1:] + title = title.replace("_", " ").title() + page.add(frame, text=title) + else: + page = ConfigFrame(container, + self.config_dict_gui[category][plugins[0]], + self.plugin_info[plugins[0]]) + + logger.debug("Built plugin config page: '%s'", category) + + return page + + def add_frame_separator(self): + """ Add a separator between top and bottom frames """ + logger.debug("Add frame seperator") + sep = ttk.Frame(self.page_frame, height=2, relief=tk.RIDGE) + sep.pack(fill=tk.X, pady=(5, 0), side=tk.BOTTOM) + logger.debug("Added frame seperator") + + def add_actions(self): + """ Add Action buttons """ + logger.debug("Add action buttons") + frame = ttk.Frame(self.page_frame) + frame.pack(fill=tk.BOTH, padx=5, pady=5, side=tk.BOTTOM) + btn_cls = ttk.Button(frame, text="Cancel", width=10, command=self.destroy) + btn_cls.pack(padx=2, side=tk.RIGHT) + btn_ok = ttk.Button(frame, text="OK", width=10, command=self.save_config) + btn_ok.pack(padx=2, side=tk.RIGHT) + logger.debug("Added action buttons") + + def save_config(self): + """ Save the config file """ + logger.debug("Saving config") + options = {sect: opts + for value in self.config_dict_gui.values() + for sect, opts in value.items()} + + new_config = ConfigParser(allow_no_value=True) + for section, items in self.config.defaults.items(): + logger.debug("Adding section: '%s')", section) + self.config.insert_config_section(section, items["helptext"], config=new_config) + for item, def_opt in items.items(): + if item == "helptext": + continue + new_opt = options[section][item] + logger.debug("Adding option: (item: '%s', default: '%s' new: '%s'", + item, def_opt, new_opt) + helptext = def_opt["helptext"] + helptext += self.config.set_helptext_choices(def_opt) + helptext += "\n[Default: {}]".format(def_opt["default"]) + helptext = self.config.format_help(helptext, is_section=False) + new_config.set(section, helptext) + new_config.set(section, item, str(new_opt["selected"].get())) + self.config.config = new_config + self.config.save_config() + print("Saved config: '{}'".format(self.config.configfile)) + self.destroy() + logger.debug("Saved config") + + +class ConfigFrame(ttk.Frame): # pylint: disable=too-many-ancestors + """ Config Frame - Holds the Options for config """ + + def __init__(self, parent, options, plugin_info): + logger.debug("Initializing %s", self.__class__.__name__) + ttk.Frame.__init__(self, parent) + self.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + + self.options = options + self.plugin_info = plugin_info + + self.canvas = tk.Canvas(self, bd=0, highlightthickness=0) + self.canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + self.optsframe = ttk.Frame(self.canvas) + self.optscanvas = self.canvas.create_window((0, 0), window=self.optsframe, anchor=tk.NW) + + self.build_frame() + logger.debug("Initialized %s", self.__class__.__name__) + + def build_frame(self): + """ Build the options frame for this command """ + logger.debug("Add Config Frame") + self.add_scrollbar() + self.canvas.bind("", self.resize_frame) + + self.add_info() + for key, val in self.options.items(): + if key == "helptext": + continue + OptionControl(key, val, self.optsframe) + logger.debug("Added Config Frame") + + def add_scrollbar(self): + """ Add a scrollbar to the options frame """ + logger.debug("Add Config Scrollbar") + scrollbar = ttk.Scrollbar(self, command=self.canvas.yview) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.canvas.config(yscrollcommand=scrollbar.set) + self.optsframe.bind("", self.update_scrollbar) + logger.debug("Added Config Scrollbar") + + def update_scrollbar(self, event): # pylint: disable=unused-argument + """ Update the options frame scrollbar """ + self.canvas.configure(scrollregion=self.canvas.bbox("all")) + + def resize_frame(self, event): + """ Resize the options frame to fit the canvas """ + logger.debug("Resize Config Frame") + canvas_width = event.width + self.canvas.itemconfig(self.optscanvas, width=canvas_width) + logger.debug("Resized Config Frame") + + def add_info(self): + """ Plugin information """ + info_frame = ttk.Frame(self.optsframe) + info_frame.pack(fill=tk.X, expand=True) + lbl = ttk.Label(info_frame, text="About:", width=20, anchor=tk.W) + lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) + info = ttk.Label(info_frame, text=self.plugin_info) + info.pack(padx=5, pady=5, fill=tk.X, expand=True) + + +class OptionControl(): + """ Build the correct control for the option parsed and place it on the + frame """ + + def __init__(self, title, values, option_frame): + logger.debug("Initializing %s", self.__class__.__name__) + self.title = title + self.values = values + self.option_frame = option_frame + + self.control = self.set_control() + self.control_frame = self.set_control_frame() + self.tk_var = self.set_tk_var() + + self.build_full_control() + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def helptext(self): + """ Format the help text for tooltips """ + logger.debug("Format control help: '%s'", self.title) + helptext = self.values.get("helptext", "") + helptext = helptext.replace("\n\t", "\n - ").replace("%%", "%") + helptext = self.title + " - " + helptext + logger.debug("Formatted control help: (title: '%s', help: '%s'", self.title, helptext) + return helptext + + def set_control(self): + """ Set the correct control type for this option """ + dtype = self.values["type"] + choices = self.values["choices"] + if choices: + control = ttk.Combobox + elif dtype == bool: + control = ttk.Checkbutton + elif dtype in (int, float): + control = ttk.Scale + else: + control = ttk.Entry + logger.debug("Setting control '%s' to %s", self.title, control) + return control + + def set_control_frame(self): + """ Frame to hold control and it's label """ + logger.debug("Build config control frame") + frame = ttk.Frame(self.option_frame) + frame.pack(fill=tk.X, expand=True) + logger.debug("Built confog control frame") + return frame + + def set_tk_var(self): + """ Correct variable type for control """ + logger.debug("Setting config variable type: '%s'", self.title) + var = tk.BooleanVar if self.control == ttk.Checkbutton else tk.StringVar + var = var(self.control_frame) + logger.debug("Set config variable type: ('%s': %s", self.title, type(var)) + return var + + def build_full_control(self): + """ Build the correct control type for the option passed through """ + logger.debug("Build confog option control") + self.build_control_label() + self.build_one_control() + self.values["selected"] = self.tk_var + logger.debug("Built option control") + + def build_control_label(self): + """ Label for control """ + logger.debug("Build config control label: '%s'", self.title) + title = self.title.replace("_", " ").title() + lbl = ttk.Label(self.control_frame, text=title, width=20, anchor=tk.W) + lbl.pack(padx=5, pady=5, side=tk.LEFT, anchor=tk.N) + logger.debug("Built config control label: '%s'", self.title) + + def build_one_control(self): + """ Build and place the option controls """ + logger.debug("Build control: (title: '%s', values: %s)", self.title, self.values) + self.tk_var.set(self.values["value"]) + + if self.control == ttk.Scale: + self.slider_control() + else: + self.control_to_optionsframe() + logger.debug("Built control: '%s'", self.title) + + def slider_control(self): + """ A slider control with corresponding Entry box """ + logger.debug("Add slider control to Config Options Frame: %s", self.control) + d_type = self.values["type"] + rnd = self.values["rounding"] + min_max = self.values["min_max"] + + tbox = ttk.Entry(self.control_frame, width=8, textvariable=self.tk_var, justify=tk.RIGHT) + tbox.pack(padx=(0, 5), side=tk.RIGHT) + ctl = self.control( + self.control_frame, + variable=self.tk_var, + command=lambda val, var=self.tk_var, dt=d_type, rn=rnd, mm=min_max: + set_slider_rounding(val, var, dt, rn, mm)) + ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) + rc_menu = ContextMenu(ctl) + rc_menu.cm_bind() + ctl["from_"] = min_max[0] + ctl["to"] = min_max[1] + + Tooltip(ctl, text=self.helptext, wraplength=720) + Tooltip(tbox, text=self.helptext, wraplength=720) + logger.debug("Added slider control to Options Frame: %s", self.control) + + def control_to_optionsframe(self): + """ Standard non-check buttons sit in the main options frame """ + logger.debug("Add control to Options Frame: %s", self.control) + choices = self.values["choices"] + if self.control == ttk.Checkbutton: + ctl = self.control(self.control_frame, variable=self.tk_var, text=None) + else: + ctl = self.control(self.control_frame, textvariable=self.tk_var) + ctl.pack(padx=5, pady=5, fill=tk.X, expand=True) + rc_menu = ContextMenu(ctl) + rc_menu.cm_bind() + if choices: + logger.debug("Adding combo choices: %s", choices) + ctl["values"] = [choice for choice in choices] + Tooltip(ctl, text=self.helptext, wraplength=720) + logger.debug("Added control to Options Frame: %s", self.control) diff --git a/lib/gui/stats.py b/lib/gui/stats.py index 781e9c4f45..b792a90cb8 100644 --- a/lib/gui/stats.py +++ b/lib/gui/stats.py @@ -9,14 +9,14 @@ from math import ceil, sqrt import numpy as np - -from lib.Serializer import PickleSerializer +import tensorflow as tf +from lib.Serializer import JSONSerializer logger = logging.getLogger(__name__) # pylint: disable=invalid-name def convert_time(timestamp): - """ Convert time stamp to total hours, mins and second """ + """ Convert time stamp to total hours, minutes and seconds """ hrs = int(timestamp // 3600) if hrs < 10: hrs = "{0:02d}".format(hrs) @@ -25,164 +25,279 @@ def convert_time(timestamp): return hrs, mins, secs -class SavedSessions(): - """ Saved Training Session """ - def __init__(self, sessions_data): - self.serializer = PickleSerializer - self.sessions = self.load_sessions(sessions_data) - - def load_sessions(self, filename): - """ Load previously saved sessions """ - stats = list() - if os.path.isfile(filename): - with open(filename, self.serializer.roptions) as sessions: - stats = self.serializer.unmarshal(sessions.read()) - return stats - - def save_sessions(self, filename): - """ Save the session file """ - with open(filename, self.serializer.woptions) as session: - session.write(self.serializer.marshal(self.sessions)) - logger.info("Saved session stats to: %s", filename) - - -class CurrentSession(): - """ The current training session """ - def __init__(self): - self.stats = {"iterations": 0, - "batchsize": None, # Set and reset by wrapper - "timestamps": [], - "loss": [], - "losskeys": []} - self.timestats = {"start": None, - "elapsed": None} - self.modeldir = None # Set and reset by wrapper - self.filename = None - self.historical = None - - def initialise_session(self, currentloss): - """ Initialise the training session """ - self.load_historical() - for item in currentloss: - self.stats["losskeys"].append(item[0]) - self.stats["loss"].append(list()) - self.timestats["start"] = time.time() - - def load_historical(self): - """ Load historical data and add current session to the end """ - self.filename = os.path.join(self.modeldir, "trainingstats.fss") - self.historical = SavedSessions(self.filename) - self.historical.sessions.append(self.stats) - - def add_loss(self, currentloss): - """ Add a loss item from the training process """ - if self.stats["iterations"] == 0: - self.initialise_session(currentloss) - - self.stats["iterations"] += 1 - self.add_timestats() - - for idx, item in enumerate(currentloss): - self.stats["loss"][idx].append(float(item[1])) - - def add_timestats(self): - """ Add timestats to loss dict and timestats """ - now = time.time() - self.stats["timestamps"].append(now) - elapsed_time = now - self.timestats["start"] - hrs, mins, secs = convert_time(elapsed_time) - self.timestats["elapsed"] = "{}:{}:{}".format(hrs, mins, secs) - - def save_session(self): - """ Save the session file to the modeldir """ - if self.stats["iterations"] > 0: - logger.info("Saving session stats...") - self.historical.save_sessions(self.filename) - - -class SessionsTotals(): - """ The compiled totals of all saved sessions """ - def __init__(self, all_sessions): - self.stats = {"split": [], - "iterations": 0, - "batchsize": [], - "timestamps": [], - "loss": [], - "losskeys": []} - - self.initiate(all_sessions) - self.compile(all_sessions) - - def initiate(self, sessions): - """ Initiate correct losskey titles and number of loss lists """ - for losskey in sessions[0]["losskeys"]: - self.stats["losskeys"].append(losskey) - self.stats["loss"].append(list()) - - def compile(self, sessions): - """ Compile all of the sessions into totals """ - current_split = 0 - for session in sessions: - iterations = session["iterations"] - current_split += iterations - self.stats["split"].append(current_split) - self.stats["iterations"] += iterations - self.stats["timestamps"].extend(session["timestamps"]) - self.stats["batchsize"].append(session["batchsize"]) - self.add_loss(session["loss"]) - - def add_loss(self, session_loss): - """ Add loss vals to each of their respective lists """ - for idx, loss in enumerate(session_loss): - self.stats["loss"][idx].extend(loss) +class TensorBoardLogs(): + """ Parse and return data from TensorBoard logs """ + def __init__(self, logs_folder): + self.folder_base = logs_folder + self.log_filenames = self.set_log_filenames() + + def set_log_filenames(self): + """ Set the TensorBoard log filenames for all existing sessions """ + logger.debug("Loading log filenames. base_dir: '%s'", self.folder_base) + log_filenames = dict() + for dirpath, _, filenames in os.walk(self.folder_base): + if not any(filename.startswith("events.out.tfevents") for filename in filenames): + continue + logfiles = [filename for filename in filenames + if filename.startswith("events.out.tfevents")] + # Take the last logfile, in case of previous crash + logfile = os.path.join(dirpath, sorted(logfiles)[-1]) + side, session = os.path.split(dirpath) + side = os.path.split(side)[1] + session = int(session[session.rfind("_") + 1:]) + log_filenames.setdefault(session, dict())[side] = logfile + logger.debug("logfiles: %s", log_filenames) + return log_filenames + + def get_loss(self, side=None, session=None): + """ Read the loss from the TensorBoard logs + Specify a side or a session or leave at None for all + """ + logger.debug("Getting loss: (side: %s, session: %s)", side, session) + all_loss = dict() + for sess, sides in self.log_filenames.items(): + if session is not None and sess != session: + logger.debug("Skipping session: %s", sess) + continue + loss = dict() + for sde, logfile in sides.items(): + if side is not None and sde != side: + logger.debug("Skipping side: %s", sde) + continue + for event in tf.train.summary_iterator(logfile): + for summary in event.summary.value: + if "loss" not in summary.tag: + continue + tag = summary.tag.replace("batch_", "") + loss.setdefault(tag, + dict()).setdefault(sde, + list()).append(summary.simple_value) + all_loss[sess] = loss + return all_loss + + def get_timestamps(self, session=None): + """ Read the timestamps from the TensorBoard logs + Specify a session or leave at None for all + NB: For all intents and purposes timestamps are the same for + both sides, so just read from one side """ + logger.debug("Getting timestamps") + all_timestamps = dict() + for sess, sides in self.log_filenames.items(): + if session is not None and sess != session: + logger.debug("Skipping sessions: %s", sess) + continue + for logfile in sides.values(): + timestamps = [event.wall_time + for event in tf.train.summary_iterator(logfile)] + logger.debug("Total timestamps for session %s: %s", sess, len(timestamps)) + all_timestamps[sess] = timestamps + break # break after first file read + return all_timestamps + + +class Session(): + """ The Loaded or current training session """ + def __init__(self, model_dir=None, model_name=None): + logger.debug("Initializing %s", self.__class__.__name__) + self.serializer = JSONSerializer + self.state = None + self.modeldir = model_dir # Set and reset by wrapper for training sessions + self.modelname = model_name # Set and reset by wrapper for training sessions + self.tb_logs = None + self.initialized = False + self.session_id = None # Set to specific session_id or current training session + self.summary = SessionsSummary(self) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def batchsize(self): + """ Return the session batchsize """ + return self.session["batchsize"] + + @property + def config(self): + """ Return config and other information """ + retval = {key: val for key, val in self.state["config"]} + retval["training_size"] = self.state["training_size"] + retval["input_size"] = [val[0] for key, val in self.state["inputs"].items() + if key.startswith("face")][0] + return retval + + @property + def full_summary(self): + """ Retun all sessions summary data""" + return self.summary.compile_stats() + + @property + def iterations(self): + """ Return session iterations """ + return self.session["iterations"] + + @property + def logging_disabled(self): + """ Return whether logging is disabled for this session """ + return self.session["no_logs"] + + @property + def loss(self): + """ Return loss from logs for current session """ + loss_dict = self.tb_logs.get_loss(session=self.session_id)[self.session_id] + return loss_dict + + @property + def loss_keys(self): + """ Return list of unique session loss keys """ + if self.session_id is None: + loss_keys = self.total_loss_keys + else: + loss_keys = set(loss_key for side_keys in self.session["loss_names"].values() + for loss_key in side_keys) + return list(loss_keys) + + @property + def lowest_loss(self): + """ Return the lowest average loss per save iteration seen """ + return self.state["lowest_avg_loss"] + + @property + def session(self): + """ Return current session dictionary """ + return self.state["sessions"][str(self.session_id)] + + @property + def session_ids(self): + """ Return sorted list of all existing session ids in the state file """ + return sorted([int(key) for key in self.state["sessions"].keys()]) + + @property + def timestamps(self): + """ Return timestamps from logs for current session """ + ts_dict = self.tb_logs.get_timestamps(session=self.session_id) + return ts_dict[self.session_id] + + @property + def total_batchsize(self): + """ Return all session batch sizes """ + return {int(sess_id): sess["batchsize"] + for sess_id, sess in self.state["sessions"].items()} + + @property + def total_iterations(self): + """ Return session iterations """ + return self.state["iterations"] + + @property + def total_loss(self): + """ Return collated loss for all session """ + loss_dict = dict() + for sess in self.tb_logs.get_loss().values(): + for loss_key, side_loss in sess.items(): + for side, loss in side_loss.items(): + loss_dict.setdefault(loss_key, dict()).setdefault(side, list()).extend(loss) + return loss_dict + + @property + def total_loss_keys(self): + """ Return list of unique session loss keys across all sessions """ + loss_keys = set(loss_key + for session in self.state["sessions"].values() + for loss_keys in session["loss_names"].values() + for loss_key in loss_keys) + return list(loss_keys) + + @property + def total_timestamps(self): + """ Return timestamps from logs seperated per session for all sessions """ + return self.tb_logs.get_timestamps() + + def initialize_session(self, is_training=False, session_id=None): + """ Initialize the training session """ + logger.debug("Initializing session: (is_training: %s, session_id: %s)", + is_training, session_id) + self.load_state_file() + self.tb_logs = TensorBoardLogs(os.path.join(self.modeldir, + "{}_logs".format(self.modelname))) + if is_training: + self.session_id = max(int(key) for key in self.state["sessions"].keys()) + else: + self.session_id = session_id + self.initialized = True + logger.debug("Initialized session") + + def load_state_file(self): + """ Load the current state file """ + state_file = os.path.join(self.modeldir, "{}_state.json".format(self.modelname)) + logger.debug("Loading State: '%s'", state_file) + try: + with open(state_file, "rb") as inp: + state = self.serializer.unmarshal(inp.read().decode("utf-8")) + self.state = state + logger.debug("Loaded state: %s", state) + except IOError as err: + logger.warning("Unable to load state file. Graphing disabled: %s", str(err)) class SessionsSummary(): """ Calculations for analysis summary stats """ - def __init__(self, raw_data): - self.summary = list() - self.summary_stats_compile(raw_data) - - def summary_stats_compile(self, raw_data): - """ Compile summary stats """ - raw_summaries = list() - for idx, session in enumerate(raw_data): - raw_summaries.append(self.summarise_session(idx, session)) - - totals_summary = self.summarise_totals(raw_summaries) - raw_summaries.append(totals_summary) - self.format_summaries(raw_summaries) - - # Compile Session Summaries - @staticmethod - def summarise_session(idx, session): - """ Compile stats for session passed in """ - starttime = session["timestamps"][0] - endtime = session["timestamps"][-1] - elapsed = endtime - starttime - # Bump elapsed to 0.1s if no time is recorded - # to hack around div by zero error - elapsed = 0.1 if elapsed == 0 else elapsed - rate = (session["batchsize"] * session["iterations"]) / elapsed - return {"session": idx + 1, - "start": starttime, - "end": endtime, - "elapsed": elapsed, - "rate": rate, - "batch": session["batchsize"], - "iterations": session["iterations"]} + def __init__(self, session): + logger.debug("Initializing %s: (session: %s)", self.__class__.__name__, session) + self.session = session + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def iterations(self): + """ Return session iterations sizes """ + return {int(sess_id): sess["iterations"] + for sess_id, sess in self.session.state["sessions"].items()} + + @property + def time_stats(self): + """ Return session time stats """ + ts_data = self.session.tb_logs.get_timestamps() + time_stats = {sess_id: {"start_time": min(timestamps), + "end_time": max(timestamps)} + for sess_id, timestamps in ts_data.items()} + return time_stats + + @property + def sessions_stats(self): + """ Return compiled stats """ + compiled = list() + for sess_idx, ts_data in self.time_stats.items(): + elapsed = ts_data["end_time"] - ts_data["start_time"] + batchsize = self.session.total_batchsize[sess_idx] + iterations = self.iterations[sess_idx] + compiled.append({"session": sess_idx, + "start": ts_data["start_time"], + "end": ts_data["end_time"], + "elapsed": elapsed, + "rate": (batchsize * iterations) / elapsed, + "batch": batchsize, + "iterations": iterations}) + return compiled + + def compile_stats(self): + """ Compile sessions stats with totals, format and return """ + logger.debug("Compiling sessions summary data") + compiled_stats = self.sessions_stats + logger.debug("sessions_stats: %s", compiled_stats) + total_stats = self.total_stats(compiled_stats) + compiled_stats.append(total_stats) + compiled_stats = self.format_stats(compiled_stats) + logger.debug("Final stats: %s", compiled_stats) + return compiled_stats @staticmethod - def summarise_totals(raw_summaries): - """ Compile the stats for all sessions combined """ + def total_stats(sessions_stats): + """ Return total stats """ + logger.debug("Compiling Totals") elapsed = 0 rate = 0 batchset = set() iterations = 0 - total_summaries = len(raw_summaries) - - for idx, summary in enumerate(raw_summaries): + total_summaries = len(sessions_stats) + for idx, summary in enumerate(sessions_stats): if idx == 0: starttime = summary["start"] if idx == total_summaries - 1: @@ -192,150 +307,170 @@ def summarise_totals(raw_summaries): batchset.add(summary["batch"]) iterations += summary["iterations"] batch = ",".join(str(bs) for bs in batchset) + totals = {"session": "Total", + "start": starttime, + "end": endtime, + "elapsed": elapsed, + "rate": rate / total_summaries, + "batch": batch, + "iterations": iterations} + logger.debug(totals) + return totals - return {"session": "Total", - "start": starttime, - "end": endtime, - "elapsed": elapsed, - "rate": rate / total_summaries, - "batch": batch, - "iterations": iterations} - - def format_summaries(self, raw_summaries): - """ Format the summaries nicely for display """ - for summary in raw_summaries: - summary["start"] = time.strftime("%x %X", - time.gmtime(summary["start"])) - summary["end"] = time.strftime("%x %X", - time.gmtime(summary["end"])) + @staticmethod + def format_stats(compiled_stats): + """ Format for display """ + logger.debug("Formatting stats") + for summary in compiled_stats: hrs, mins, secs = convert_time(summary["elapsed"]) + summary["start"] = time.strftime("%x %X", time.gmtime(summary["start"])) + summary["end"] = time.strftime("%x %X", time.gmtime(summary["end"])) summary["elapsed"] = "{}:{}:{}".format(hrs, mins, secs) summary["rate"] = "{0:.1f}".format(summary["rate"]) - self.summary = raw_summaries + return compiled_stats class Calculations(): - """ Class to hold calculations against raw session data """ - def __init__(self, - session, - display="loss", - selections=["raw"], - avg_samples=10, - flatten_outliers=False, - is_totals=False): + """ Class to pull raw data for given session(s) and perform calculations """ + def __init__(self, session, display="loss", loss_keys=["loss"], selections=["raw"], + avg_samples=10, flatten_outliers=False, is_totals=False): + logger.debug("Initializing %s: (session: %s, display: %s, loss_keys: %s, selections: %s, " + "avg_samples: %s, flatten_outliers: %s, is_totals: %s", + self.__class__.__name__, session, display, loss_keys, selections, avg_samples, + flatten_outliers, is_totals) warnings.simplefilter("ignore", np.RankWarning) self.session = session - if display.lower() == "loss": - display = self.session["losskeys"] - else: - display = [display] - self.args = {"display": display, - "selections": selections, - "avg_samples": int(avg_samples), - "flatten_outliers": flatten_outliers, - "is_totals": is_totals} + self.display = display + self.loss_keys = loss_keys + self.selections = selections + self.is_totals = is_totals + self.args = {"avg_samples": int(avg_samples), + "flatten_outliers": flatten_outliers} self.iterations = 0 self.stats = None self.refresh() + logger.debug("Initialized %s", self.__class__.__name__) def refresh(self): """ Refresh the stats """ + logger.debug("Refreshing") + if not self.session.initialized: + logger.warning("Session data is not initialized. Not refreshing") + return self.iterations = 0 self.stats = self.get_raw() self.get_calculations() self.remove_raw() + logger.debug("Refreshed") def get_raw(self): """ Add raw data to stats dict """ - raw = dict() - for idx, item in enumerate(self.args["display"]): - if item.lower() == "rate": - data = self.calc_rate(self.session) - else: - data = self.session["loss"][idx][:] + logger.debug("Getting Raw Data") + raw = dict() + iterations = set() + if self.display.lower() == "loss": + loss_dict = self.session.total_loss if self.is_totals else self.session.loss + for loss_name, side_loss in loss_dict.items(): + if loss_name not in self.loss_keys: + continue + for side, loss in side_loss.items(): + if self.args["flatten_outliers"]: + loss = self.flatten_outliers(loss) + iterations.add(len(loss)) + raw["raw_{}_{}".format(loss_name, side)] = loss + + self.iterations = 0 if not iterations else min(iterations) + if len(iterations) > 1: + # Crop all losses to the same number of items + if self.iterations == 0: + raw = {lossname: list() for lossname in raw.keys()} + else: + raw = {lossname: loss[:self.iterations] for lossname, loss in raw} + + else: # Rate calulation + data = self.calc_rate_total() if self.is_totals else self.calc_rate() if self.args["flatten_outliers"]: data = self.flatten_outliers(data) + self.iterations = len(data) + raw = {"raw_rate": data} - if self.iterations == 0: - self.iterations = len(data) - - raw["raw_{}".format(item)] = data + logger.debug("Got Raw Data") return raw def remove_raw(self): """ Remove raw values from stats if not requested """ - if "raw" in self.args["selections"]: + if "raw" in self.selections: return + logger.debug("Removing Raw Data from output") for key in list(self.stats.keys()): if key.startswith("raw"): del self.stats[key] + logger.debug("Removed Raw Data from output") + + def calc_rate(self): + """ Calculate rate per iteration """ + logger.debug("Calculating rate") + batchsize = self.session.batchsize + timestamps = self.session.timestamps + iterations = range(len(timestamps) - 1) + rate = [batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations] + logger.debug("Calculated rate: Item_count: %s", len(rate)) + return rate - def calc_rate(self, data): + def calc_rate_total(self): """ Calculate rate per iteration NB: For totals, gaps between sessions can be large - so time diffeence has to be reset for each session's + so time difference has to be reset for each session's rate calculation """ - batchsize = data["batchsize"] - if self.args["is_totals"]: - split = data["split"] - else: - batchsize = [batchsize] - split = [len(data["timestamps"])] - - prev_split = 0 + logger.debug("Calculating totals rate") + batchsizes = self.session.total_batchsize + total_timestamps = self.session.total_timestamps rate = list() - - for idx, current_split in enumerate(split): - prev_time = data["timestamps"][prev_split] - timestamp_chunk = data["timestamps"][prev_split:current_split] - for item in timestamp_chunk: - current_time = item - timediff = current_time - prev_time - iter_rate = 0 if timediff == 0 else batchsize[idx] / timediff - rate.append(iter_rate) - prev_time = current_time - prev_split = current_split - - if self.args["flatten_outliers"]: - rate = self.flatten_outliers(rate) + for sess_id in sorted(total_timestamps.keys()): + batchsize = batchsizes[sess_id] + timestamps = total_timestamps[sess_id] + iterations = range(len(timestamps) - 1) + rate.extend([batchsize / (timestamps[i + 1] - timestamps[i]) for i in iterations]) + logger.debug("Calculated totals rate: Item_count: %s", len(rate)) return rate @staticmethod def flatten_outliers(data): """ Remove the outliers from a provided list """ + logger.debug("Flattening outliers") retdata = list() samples = len(data) mean = (sum(data) / samples) limit = sqrt(sum([(item - mean)**2 for item in data]) / samples) + logger.debug("samples: %s, mean: %s, limit: %s", samples, mean, limit) - for item in data: + for idx, item in enumerate(data): if (mean - limit) <= item <= (mean + limit): retdata.append(item) else: + logger.debug("Item idx: %s, value: %s flattened to %s", idx, item, mean) retdata.append(mean) + logger.debug("Flattened outliers") return retdata def get_calculations(self): """ Perform the required calculations """ - for selection in self.get_selections(): - if selection[0] == "raw": + for selection in self.selections: + if selection == "raw": continue - method = getattr(self, "calc_{}".format(selection[0])) - key = "{}_{}".format(selection[0], selection[1]) - raw = self.stats["raw_{}".format(selection[1])] - self.stats[key] = method(raw) - - def get_selections(self): - """ Compile a list of data to be calculated """ - for summary in self.args["selections"]: - for item in self.args["display"]: - yield summary, item + logger.debug("Calculating: %s", selection) + method = getattr(self, "calc_{}".format(selection)) + raw_keys = [key for key in self.stats.keys() if key.startswith("raw_")] + for key in raw_keys: + selected_key = "{}_{}".format(selection, key.replace("raw_", "")) + self.stats[selected_key] = method(self.stats[key]) def calc_avg(self, data): """ Calculate rolling average """ + logger.debug("Calculating Average") avgs = list() presample = ceil(self.args["avg_samples"] / 2) postsample = self.args["avg_samples"] - presample @@ -353,11 +488,13 @@ def calc_avg(self, data): avg = sum(data[idx - presample:idx + postsample]) \ / self.args["avg_samples"] avgs.append(avg) + logger.debug("Calculated Average") return avgs @staticmethod def calc_trend(data): """ Compile trend data """ + logger.debug("Calculating Trend") points = len(data) if points < 10: dummy = [None for i in range(points)] @@ -366,4 +503,5 @@ def calc_trend(data): fit = np.polyfit(x_range, data, 3) poly = np.poly1d(fit) trend = poly(x_range) + logger.debug("Calculated Trend") return trend diff --git a/lib/gui/tooltip.py b/lib/gui/tooltip.py index 6e2782cabe..d89e8eb58c 100755 --- a/lib/gui/tooltip.py +++ b/lib/gui/tooltip.py @@ -42,7 +42,7 @@ def __init__(self, widget, waittime=400, wraplength=250): - self.waittime = waittime # in miliseconds, originally 500 + self.waittime = waittime # in milliseconds, originally 500 self.wraplength = wraplength # in pixels, originally 180 self.widget = widget self.text = text @@ -115,7 +115,7 @@ def tip_pos_calculator(widget, label, # No further checks will be done. # TIP: - # A further mod might automagically augment the + # A further mod might auto-magically augment the # wraplength when the tooltip is too high to be # kept inside the screen. y_1 = 0 diff --git a/lib/gui/utils.py b/lib/gui/utils.py index 3b75737e56..325ecb61e2 100644 --- a/lib/gui/utils.py +++ b/lib/gui/utils.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 """ Utility functions for the GUI """ import logging - import os import platform import sys @@ -10,24 +9,50 @@ from tkinter import filedialog, ttk from PIL import Image, ImageTk +from lib.Serializer import JSONSerializer + logger = logging.getLogger(__name__) # pylint: disable=invalid-name +_CONFIG = None +_IMAGES = None + + +def initialize_config(cli_opts, scaling_factor, pathcache, statusbar, session): + """ Initialize the config and add to global constant """ + global _CONFIG # pylint: disable=global-statement + if _CONFIG is not None: + return + logger.debug("Initializing config: (cli_opts: %s, tk_vars: %s, pathcache: %s, statusbar: %s, " + "session: %s)", cli_opts, scaling_factor, pathcache, statusbar, session) + _CONFIG = Config(cli_opts, scaling_factor, pathcache, statusbar, session) + +def get_config(): + """ return the _CONFIG constant """ + return _CONFIG -class Singleton(type): - """ Instigate a singleton. - From: https://stackoverflow.com/questions/6760685 - Singletons are often frowned upon. - Feel free to instigate a better solution """ +def initialize_images(): + """ Initialize the config and add to global constant """ + global _IMAGES # pylint: disable=global-statement + if _IMAGES is not None: + return + logger.debug("Initializing images") + _IMAGES = Images() + + +def get_images(): + """ return the _CONFIG constant """ + return _IMAGES - _instances = {} - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, - cls).__call__(*args, - **kwargs) - return cls._instances[cls] +def set_slider_rounding(value, var, d_type, round_to, min_max): + """ Set the underlying variable to correct number based on slider rounding """ + if d_type == float: + var.set(round(float(value), round_to)) + else: + steps = range(min_max[0], min_max[1] + round_to, round_to) + value = min(steps, key=lambda x: abs(x - int(float(value)))) + var.set(value) class FileHandler(): @@ -52,8 +77,8 @@ def __init__(self, handletype, filetype, command=None, action=None, ("PNG", "*.png"), ("TIFF", "*.tif", "*.tiff"), all_files), + "state": (("State files", "*.json"), all_files), "log": (("Log files", "*.log"), all_files), - "session": (("Faceswap session files", "*.fss"), all_files), "video": (("Audio Video Interleave", "*.avi"), ("Flash Video", "*.flv"), ("Matroska", "*.mkv"), @@ -164,11 +189,15 @@ def nothing(): # pylint: disable=useless-return return -class Images(metaclass=Singleton): - """ Holds locations of images and actual images """ +class Images(): + """ Holds locations of images and actual images - def __init__(self, pathcache=None): - logger.debug("Initializing %s: (pathcache: '%s'", self.__class__.__name__, pathcache) + Don't call directly. Call get_images() + """ + + def __init__(self): + logger.debug("Initializing %s", self.__class__.__name__) + pathcache = get_config().pathcache self.pathicons = os.path.join(pathcache, "icons") self.pathpreview = os.path.join(pathcache, "preview") self.pathoutput = None @@ -194,7 +223,7 @@ def delete_preview(self): """ Delete the preview files """ logger.debug("Deleting previews") for item in os.listdir(self.pathpreview): - if item.startswith(".gui_preview_") and item.endswith(".jpg"): + if item.startswith(".gui_training_preview") and item.endswith(".jpg"): fullitem = os.path.join(self.pathpreview, item) logger.debug("Deleting: '%s'", fullitem) os.remove(fullitem) @@ -210,34 +239,34 @@ def clear_image_cache(self): @staticmethod def get_images(imgpath): """ Get the images stored within the given directory """ - logger.debug("Getting images: '%s'", imgpath) + logger.trace("Getting images: '%s'", imgpath) if not os.path.isdir(imgpath): logger.debug("Folder does not exist") return None files = [os.path.join(imgpath, f) for f in os.listdir(imgpath) if f.endswith((".png", ".jpg"))] - logger.debug("Image files: %s", files) + logger.trace("Image files: %s", files) return files def load_latest_preview(self): """ Load the latest preview image for extract and convert """ - logger.debug("Loading preview image") + logger.trace("Loading preview image") imagefiles = self.get_images(self.pathoutput) if not imagefiles or len(imagefiles) == 1: logger.debug("No preview to display") self.previewoutput = None return - # Get penultimate file so we don't accidently + # Get penultimate file so we don't accidentally # load a file that is being saved show_file = sorted(imagefiles, key=os.path.getctime)[-2] img = Image.open(show_file) img.thumbnail((768, 432)) - logger.debug("Displaying preview: '%s'", show_file) + logger.trace("Displaying preview: '%s'", show_file) self.previewoutput = (img, ImageTk.PhotoImage(img)) def load_training_preview(self): """ Load the training preview images """ - logger.debug("Loading Training preview images") + logger.trace("Loading Training preview images") imagefiles = self.get_images(self.pathpreview) modified = None if not imagefiles: @@ -250,7 +279,7 @@ def load_training_preview(self): name = os.path.splitext(name)[0] name = name[name.rfind("_") + 1:].title() try: - logger.debug("Displaying preview: '%s'", img) + logger.trace("Displaying preview: '%s'", img) size = self.get_current_size(name) self.previewtrain[name] = [Image.open(img), None, modified] self.resize_image(name, size) @@ -270,20 +299,20 @@ def load_training_preview(self): def get_current_size(self, name): """ Return the size of the currently displayed image """ - logger.debug("Getting size: '%s'", name) + logger.trace("Getting size: '%s'", name) if not self.previewtrain.get(name, None): return None img = self.previewtrain[name][1] if not img: return None - logger.debug("Got size: (name: '%s', width: '%s', height: '%s')", + logger.trace("Got size: (name: '%s', width: '%s', height: '%s')", name, img.width(), img.height()) return img.width(), img.height() def resize_image(self, name, framesize): """ Resize the training preview image based on the passed in frame size """ - logger.debug("Resizing image: (name: '%s', framesize: %s", name, framesize) + logger.trace("Resizing image: (name: '%s', framesize: %s", name, framesize) displayimg = self.previewtrain[name][0] if framesize: frameratio = float(framesize[0]) / float(framesize[1]) @@ -295,7 +324,7 @@ def resize_image(self, name, framesize): else: scale = framesize[1] / float(displayimg.size[1]) size = (int(displayimg.size[0] * scale), framesize[1]) - logger.debug("Scaling: (scale: %s, size: %s", scale, size) + logger.trace("Scaling: (scale: %s, size: %s", scale, size) # Hacky fix to force a reload if it happens to find corrupted # data, probably due to reading the image whilst it is partially @@ -335,7 +364,9 @@ def cm_bind(self): """ Bind the menu to the widget's Right Click event """ button = "" if platform.system() == "Darwin" else "" logger.debug("Binding '%s' to '%s'", button, self.widget.winfo_class()) - self.widget.bind(button, lambda event: self.tk_popup(event.x_root, event.y_root, 0)) + x_offset = int(34 * get_config().scaling_factor) + self.widget.bind(button, + lambda event: self.tk_popup(event.x_root + x_offset, event.y_root, 0)) def select_all(self): """ Select all for Text or Entry widgets """ @@ -351,16 +382,16 @@ def select_all(self): class ConsoleOut(ttk.Frame): # pylint: disable=too-many-ancestors """ The Console out section of the GUI """ - def __init__(self, parent, debug, tk_vars): - logger.debug("Initializing %s: (parent: %s, debug: %s, tk_vars: %s)", - self.__class__.__name__, parent, debug, tk_vars) + def __init__(self, parent, debug): + logger.debug("Initializing %s: (parent: %s, debug: %s)", + self.__class__.__name__, parent, debug) ttk.Frame.__init__(self, parent) self.pack(side=tk.TOP, anchor=tk.W, padx=10, pady=(2, 0), fill=tk.BOTH, expand=True) self.console = tk.Text(self) rc_menu = ContextMenu(self.console) rc_menu.cm_bind() - self.console_clear = tk_vars['consoleclear'] + self.console_clear = get_config().tk_vars['consoleclear'] self.set_console_clear_var_trace() self.debug = debug self.build_console() @@ -395,7 +426,7 @@ def redirect_console(self): sys.stderr = SysOutRouter(console=self.console, out_type="stderr") logger.debug("Redirected console") - def clear(self, *args): + def clear(self, *args): # pylint: disable=unused-argument """ Clear the console output screen """ logger.debug("Clear console") if not self.console_clear.get(): @@ -427,3 +458,146 @@ def write(self, string): def flush(): """ If flush is forced, send it to normal terminal """ sys.__stdout__.flush() + + +class Config(): + """ Global configuration settings + + Don't call directly. Call get_config() + """ + + def __init__(self, cli_opts, scaling_factor, pathcache, statusbar, session): + logger.debug("Initializing %s: (cli_opts: %s, scaling_factor: %s, pathcache: %s, " + "statusbar: %s, session: %s)", self.__class__.__name__, cli_opts, + scaling_factor, pathcache, statusbar, session) + self.cli_opts = cli_opts + self.scaling_factor = scaling_factor + self.pathcache = pathcache + self.statusbar = statusbar + self.serializer = JSONSerializer + self.tk_vars = self.set_tk_vars() + self.command_notebook = None # set in command.py + self.session = session + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def command_tabs(self): + """ Return dict of command tab titles with their IDs """ + return {self.command_notebook.tab(tab_id, "text").lower(): tab_id + for tab_id in range(0, self.command_notebook.index("end"))} + + @staticmethod + def set_tk_vars(): + """ TK Variables to be triggered by to indicate + what state various parts of the GUI should be in """ + display = tk.StringVar() + display.set(None) + + runningtask = tk.BooleanVar() + runningtask.set(False) + + actioncommand = tk.StringVar() + actioncommand.set(None) + + generatecommand = tk.StringVar() + generatecommand.set(None) + + consoleclear = tk.BooleanVar() + consoleclear.set(False) + + refreshgraph = tk.BooleanVar() + refreshgraph.set(False) + + updatepreview = tk.BooleanVar() + updatepreview.set(False) + + tk_vars = {"display": display, + "runningtask": runningtask, + "action": actioncommand, + "generate": generatecommand, + "consoleclear": consoleclear, + "refreshgraph": refreshgraph, + "updatepreview": updatepreview} + logger.debug(tk_vars) + return tk_vars + + def load(self, command=None, filename=None): + """ Pop up load dialog for a saved config file """ + logger.debug("Loading config: (command: '%s')", command) + if filename: + with open(filename, "r") as cfgfile: + cfg = self.serializer.unmarshal(cfgfile.read()) + else: + cfgfile = FileHandler("open", "config").retfile + if not cfgfile: + return + cfg = self.serializer.unmarshal(cfgfile.read()) + + if not command and len(cfg.keys()) == 1: + command = list(cfg.keys())[0] + + opts = self.get_command_options(cfg, command) if command else cfg + if not opts: + return + + for cmd, opts in opts.items(): + self.set_command_args(cmd, opts) + + if command: + self.command_notebook.select(self.command_tabs[command]) + + self.add_to_recent(cfgfile.name, command) + logger.debug("Loaded config: (command: '%s', cfgfile: '%s')", command, cfgfile) + + def get_command_options(self, cfg, command): + """ return the saved options for the requested + command, if not loading global options """ + opts = cfg.get(command, None) + retval = {command: opts} + if not opts: + self.tk_vars["consoleclear"].set(True) + print("No {} section found in file".format(command)) + logger.info("No %s section found in file", command) + retval = None + logger.debug(retval) + return retval + + def set_command_args(self, command, options): + """ Pass the saved config items back to the CliOptions """ + if not options: + return + for srcopt, srcval in options.items(): + optvar = self.cli_opts.get_one_option_variable(command, srcopt) + if not optvar: + continue + optvar.set(srcval) + + def save(self, command=None): + """ Save the current GUI state to a config file in json format """ + logger.debug("Saving config: (command: '%s')", command) + cfgfile = FileHandler("save", "config").retfile + if not cfgfile: + return + cfg = self.cli_opts.get_option_values(command) + cfgfile.write(self.serializer.marshal(cfg)) + cfgfile.close() + self.add_to_recent(cfgfile.name, command) + logger.debug("Saved config: (command: '%s', cfgfile: '%s')", command, cfgfile) + + def add_to_recent(self, filename, command): + """ Add to recent files """ + recent_filename = os.path.join(self.pathcache, ".recent.json") + logger.debug("Adding to recent files '%s': (%s, %s)", recent_filename, filename, command) + with open(recent_filename, "rb") as inp: + recent_files = self.serializer.unmarshal(inp.read().decode("utf-8")) + logger.debug("Initial recent files: %s", recent_files) + filenames = [recent[0] for recent in recent_files] + if filename in filenames: + idx = filenames.index(filename) + del recent_files[idx] + recent_files.insert(0, (filename, command)) + recent_files = recent_files[:20] + logger.debug("Final recent files: %s", recent_files) + recent_json = self.serializer.marshal(recent_files) + with open(recent_filename, "wb") as out: + out.write(recent_json.encode("utf-8")) diff --git a/lib/gui/wrapper.py b/lib/gui/wrapper.py index 39b16986ef..67e41fdd93 100644 --- a/lib/gui/wrapper.py +++ b/lib/gui/wrapper.py @@ -1,57 +1,40 @@ #!/usr/bin python3 """ Process wrapper for underlying faceswap commands for the GUI """ import os +import logging import re import signal from subprocess import PIPE, Popen, TimeoutExpired import sys -import tkinter as tk from threading import Thread from time import time import psutil -from .utils import Images +from .utils import get_config, get_images + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class ProcessWrapper(): """ Builds command, launches and terminates the underlying faceswap process. Updates GUI display depending on state """ - def __init__(self, statusbar, session=None, pathscript=None, cliopts=None): - self.tk_vars = self.set_tk_vars() - self.session = session + def __init__(self, pathscript=None): + logger.debug("Initializing %s: (pathscript: %s)", self.__class__.__name__, pathscript) + self.tk_vars = get_config().tk_vars + self.set_callbacks() self.pathscript = pathscript - self.cliopts = cliopts self.command = None - self.statusbar = statusbar + self.statusbar = get_config().statusbar self.task = FaceswapControl(self) + logger.debug("Initialized %s", self.__class__.__name__) - def set_tk_vars(self): - """ TK Variables to be triggered by ProcessWrapper to indicate - what state various parts of the GUI should be in """ - display = tk.StringVar() - display.set(None) - - runningtask = tk.BooleanVar() - runningtask.set(False) - - actioncommand = tk.StringVar() - actioncommand.set(None) - actioncommand.trace("w", self.action_command) - - generatecommand = tk.StringVar() - generatecommand.set(None) - generatecommand.trace("w", self.generate_command) - - consoleclear = tk.BooleanVar() - consoleclear.set(False) - - return {"display": display, - "runningtask": runningtask, - "action": actioncommand, - "generate": generatecommand, - "consoleclear": consoleclear} + def set_callbacks(self): + """ Set the tk variable callbacks """ + logger.debug("Setting tk variable traces") + self.tk_vars["action"].trace("w", self.action_command) + self.tk_vars["generate"].trace("w", self.generate_command) def action_command(self, *args): """ The action to perform when the action button is pressed """ @@ -74,28 +57,30 @@ def generate_command(self, *args): category, command = self.tk_vars["generate"].get().split(",") args = self.build_args(category, command=command, generate=True) self.tk_vars["consoleclear"].set(True) + logger.debug(" ".join(args)) print(" ".join(args)) self.tk_vars["generate"].set(None) def prepare(self, category): """ Prepare the environment for execution """ + logger.debug("Preparing for execution") self.tk_vars["runningtask"].set(True) self.tk_vars["consoleclear"].set(True) print("Loading...") - self.statusbar.status_message.set("Executing - " - + self.command + ".py") - mode = "indeterminate" if self.command in ("effmpeg", - "train") else "determinate" + self.statusbar.status_message.set("Executing - {}.py".format(self.command)) + mode = "indeterminate" if self.command in ("effmpeg", "train") else "determinate" self.statusbar.progress_start(mode) args = self.build_args(category) self.tk_vars["display"].set(self.command) - + logger.debug("Prepared for execution") return args def build_args(self, category, command=None, generate=False): """ Build the faceswap command and arguments list """ + logger.debug("Build cli arguments: (category: %s, command: %s, generate: %s)", + category, command, generate) command = self.command if not command else command script = "{}.{}".format(category, "py") pathexecscript = os.path.join(self.pathscript, script) @@ -103,50 +88,60 @@ def build_args(self, category, command=None, generate=False): args = [sys.executable] if generate else [sys.executable, "-u"] args.extend([pathexecscript, command]) - for cliopt in self.cliopts.gen_cli_arguments(command): + cli_opts = get_config().cli_opts + for cliopt in cli_opts.gen_cli_arguments(command): args.extend(cliopt) if command == "train" and not generate: - self.set_session_stats(cliopt) - if command == "train" and not generate: - args.append("-gui") # Embed the preview pane + self.init_training_session(cliopt) + if not generate: + args.append("-gui") # Indicate to Faceswap that we are running the GUI + logger.debug("Built cli arguments: (%s)", args) return args - def set_session_stats(self, cliopt): - """ Set the session stats for batchsize and modeldir """ - if cliopt[0] == "-bs": - self.session.stats["batchsize"] = int(cliopt[1]) + @staticmethod + def init_training_session(cliopt): + """ Set the session stats for disable logging, model folder and model name """ + session = get_config().session + if cliopt[0] == "-t": + session.modelname = cliopt[1].lower().replace("-", "_") + logger.debug("modelname: '%s'", session.modelname) if cliopt[0] == "-m": - self.session.modeldir = cliopt[1] + session.modeldir = cliopt[1] + logger.debug("modeldir: '%s'", session.modeldir) def terminate(self, message): - """ Finalise wrapper when process has exited """ + """ Finalize wrapper when process has exited """ + logger.debug("Terminating Faceswap processes") self.tk_vars["runningtask"].set(False) self.statusbar.progress_stop() self.statusbar.status_message.set(message) self.tk_vars["display"].set(None) - Images().delete_preview() - if self.command == "train": - self.session.save_session() - self.session.__init__() + get_images().delete_preview() + get_config().session.__init__() self.command = None + logger.debug("Terminated Faceswap processes") print("Process exited.") class FaceswapControl(): """ Control the underlying Faceswap tasks """ def __init__(self, wrapper): - + logger.debug("Initializing %s", self.__class__.__name__) self.wrapper = wrapper - self.statusbar = wrapper.statusbar + self.statusbar = get_config().statusbar self.command = None self.args = None self.process = None + self.train_stats = {"iterations": 0, "timestamp": None} self.consoleregex = { "loss": re.compile(r"([a-zA-Z_]+):.*?(\d+\.\d+)"), - "tqdm": re.compile(r"(\d+%|\d+/\d+|\d+:\d+|\d+\.\d+[a-zA-Z/]+)")} + "tqdm": re.compile(r".*?(?P\d+%).*?(?P\d+/\d+)\W\[" + r"(?P\d+:\d+<.*),\W(?P.*)[a-zA-Z/]*\]")} + logger.debug("Initialized %s", self.__class__.__name__) def execute_script(self, command, args): """ Execute the requested Faceswap Script """ + logger.debug("Executing Faceswap: (command: '%s', args: %s)", command, args) self.command = command kwargs = {"stdout": PIPE, "stderr": PIPE, @@ -156,10 +151,12 @@ def execute_script(self, command, args): self.process = Popen(args, **kwargs, stdin=PIPE) self.thread_stdout() self.thread_stderr() + logger.debug("Executed Faceswap") def read_stdout(self): """ Read stdout from the subprocess. If training, pass the loss values to Queue """ + logger.debug("Opening stdout reader") while True: try: output = self.process.stdout.readline() @@ -173,14 +170,19 @@ def read_stdout(self): if (self.command == "train" and self.capture_loss(output)) or ( self.command != "train" and self.capture_tqdm(output)): continue + if self.command == "train" and output.strip().endswith("saved models"): + logger.debug("Trigger update preview") + self.wrapper.tk_vars["updatepreview"].set(True) print(output.strip()) returncode = self.process.poll() message = self.set_final_status(returncode) self.wrapper.terminate(message) + logger.debug("Terminated stdout reader. returncode: %s", returncode) def read_stderr(self): """ Read stdout from the subprocess. If training, pass the loss values to Queue """ + logger.debug("Opening stderr reader") while True: try: output = self.process.stderr.readline() @@ -194,81 +196,125 @@ def read_stderr(self): if self.command != "train" and self.capture_tqdm(output): continue print(output.strip(), file=sys.stderr) + logger.debug("Terminated stderr reader") def thread_stdout(self): """ Put the subprocess stdout so that it can be read without blocking """ + logger.debug("Threading stdout") thread = Thread(target=self.read_stdout) thread.daemon = True thread.start() + logger.debug("Threaded stdout") def thread_stderr(self): """ Put the subprocess stderr so that it can be read without blocking """ + logger.debug("Threading stderr") thread = Thread(target=self.read_stderr) thread.daemon = True thread.start() + logger.debug("Threaded stderr") def capture_loss(self, string): """ Capture loss values from stdout """ - + logger.trace("Capturing loss") if not str.startswith(string, "["): + logger.trace("Not loss message. Returning False") return False loss = self.consoleregex["loss"].findall(string) if len(loss) < 2: + logger.trace("Not loss message. Returning False") return False - self.wrapper.session.add_loss(loss) - message = "" for item in loss: message += "{}: {} ".format(item[0], item[1]) if not message: + logger.trace("Error creating loss message. Returning False") return False - elapsed = self.wrapper.session.timestats["elapsed"] - iterations = self.wrapper.session.stats["iterations"] + iterations = self.train_stats["iterations"] + + if iterations == 0: + # Initialize session stats and set initial timestamp + self.train_stats["timestamp"] = time() + + if not get_config().session.initialized and iterations > 0: + # Don't initialize session until after the first iteration as state + # file must exist first + get_config().session.initialize_session(is_training=True) + self.wrapper.tk_vars["refreshgraph"].set(True) + iterations += 1 + if iterations % 100 == 0: + self.wrapper.tk_vars["refreshgraph"].set(True) + self.train_stats["iterations"] = iterations + + elapsed = self.calc_elapsed() message = "Elapsed: {} Iteration: {} {}".format(elapsed, - iterations, - message) + self.train_stats["iterations"], message) self.statusbar.progress_update(message, 0, False) + logger.trace("Succesfully captured loss: %s", message) return True + def calc_elapsed(self): + """ Calculate and format time since training started """ + now = time() + elapsed_time = now - self.train_stats["timestamp"] + try: + hrs = int(elapsed_time // 3600) + if hrs < 10: + hrs = "{0:02d}".format(hrs) + mins = "{0:02d}".format((int(elapsed_time % 3600) // 60)) + secs = "{0:02d}".format((int(elapsed_time % 3600) % 60)) + except ZeroDivisionError: + hrs = "00" + mins = "00" + secs = "00" + return "{}:{}:{}".format(hrs, mins, secs) + def capture_tqdm(self, string): """ Capture tqdm output for progress bar """ - tqdm = self.consoleregex["tqdm"].findall(string) - if len(tqdm) != 5: + logger.trace("Capturing tqdm") + tqdm = self.consoleregex["tqdm"].match(string) + if not tqdm: return False - - percent = tqdm[0] - processed = tqdm[1] - processtime = "Elapsed: {} Remaining: {}".format(tqdm[2], tqdm[3]) - rate = tqdm[4] + tqdm = tqdm.groupdict() + if any("?" in val for val in tqdm.values()): + logger.trace("tqdm initializing. Skipping") + return True + processtime = "Elapsed: {} Remaining: {}".format(tqdm["tme"].split("<")[0], + tqdm["tme"].split("<")[1]) message = "{} | {} | {} | {}".format(processtime, - rate, - processed, - percent) + tqdm["rte"], + tqdm["itm"], + tqdm["pct"]) - current, total = processed.split("/") + current, total = tqdm["itm"].split("/") position = int((float(current) / float(total)) * 1000) self.statusbar.progress_update(message, position, True) + logger.trace("Succesfully captured tqdm message: %s", message) return True def terminate(self): """ Terminate the subprocess """ - if self.command != "train": + logger.debug("Terminating wrapper") + if self.command == "train": + logger.debug("Sending Exit Signal") print("Sending Exit Signal", flush=True) try: now = time() if os.name == "nt": try: + logger.debug("Sending carriage return to process") self.process.communicate(input="\n", timeout=60) except TimeoutExpired: raise ValueError("Timeout reached sending Exit Signal") else: + logger.debug("Sending SIGINT to process") self.process.send_signal(signal.SIGINT) while True: timeelapsed = time() - now @@ -278,30 +324,37 @@ def terminate(self): raise ValueError("Timeout reached sending Exit Signal") return except ValueError as err: + logger.error("Error terminating process", exc_info=True) print(err) else: + logger.debug("Terminating Process...") print("Terminating Process...") children = psutil.Process().children(recursive=True) for child in children: child.terminate() _, alive = psutil.wait_procs(children, timeout=10) if not alive: + logger.debug("Terminated") print("Terminated") return + logger.debug("Termination timed out. Killing Process...") print("Termination timed out. Killing Process...") for child in alive: child.kill() _, alive = psutil.wait_procs(alive, timeout=10) if not alive: + logger.debug("Killed") print("Killed") else: for child in alive: - print("Process {} survived SIGKILL. " - "Giving up".format(child)) + msg = "Process {} survived SIGKILL. Giving up".format(child) + logger.debug(msg) + print(msg) def set_final_status(self, returncode): """ Set the status bar output based on subprocess return code """ + logger.debug("Setting final status. returncode: %s", returncode) if returncode in (0, 3221225786): status = "Ready" elif returncode == -15: @@ -311,6 +364,6 @@ def set_final_status(self, returncode): elif returncode == -6: status = "Aborted - {}.py".format(self.command) else: - status = "Failed - {}.py. Return Code: {}".format(self.command, - returncode) + status = "Failed - {}.py. Return Code: {}".format(self.command, returncode) + logger.debug("Set final status: %s", status) return status diff --git a/lib/keypress.py b/lib/keypress.py new file mode 100644 index 0000000000..a36a049600 --- /dev/null +++ b/lib/keypress.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +""" +Source: http://home.wlu.edu/~levys/software/kbhit.py +A Python class implementing KBHIT, the standard keyboard-interrupt poller. +Works transparently on Windows and Posix (Linux, Mac OS X). Doesn't work +with IDLE. + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Lesser General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. +""" + +import os + +# Windows +if os.name == "nt": + import msvcrt # pylint: disable=import-error + +# Posix (Linux, OS X) +else: + import sys + import termios + import atexit + from select import select + + +class KBHit: + """ Creates a KBHit object that you can call to do various keyboard things. """ + def __init__(self, is_gui=False): + self.is_gui = is_gui + if os.name == "nt" or self.is_gui: + pass + else: + # Save the terminal settings + self.file_desc = sys.stdin.fileno() + self.new_term = termios.tcgetattr(self.file_desc) + self.old_term = termios.tcgetattr(self.file_desc) + + # New terminal setting unbuffered + self.new_term[3] = (self.new_term[3] & ~termios.ICANON & ~termios.ECHO) + termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.new_term) + + # Support normal-terminal reset at exit + atexit.register(self.set_normal_term) + + def set_normal_term(self): + """ Resets to normal terminal. On Windows this is a no-op. """ + if os.name == "nt" or self.is_gui: + pass + else: + termios.tcsetattr(self.file_desc, termios.TCSAFLUSH, self.old_term) + + @staticmethod + def getch(): + """ Returns a keyboard character after kbhit() has been called. + Should not be called in the same program as getarrow(). """ + if os.name == "nt": + return msvcrt.getch().decode("utf-8") + return sys.stdin.read(1) + + @staticmethod + def getarrow(): + """ Returns an arrow-key code after kbhit() has been called. Codes are + 0 : up + 1 : right + 2 : down + 3 : left + Should not be called in the same program as getch(). """ + + if os.name == "nt": + msvcrt.getch() # skip 0xE0 + char = msvcrt.getch() + vals = [72, 77, 80, 75] + else: + char = sys.stdin.read(3)[2] + vals = [65, 67, 66, 68] + + return vals.index(ord(char.decode("utf-8"))) + + @staticmethod + def kbhit(): + """ Returns True if keyboard character was hit, False otherwise. """ + if os.name == "nt": + return msvcrt.kbhit() + d_r, _, _ = select([sys.stdin], [], [], 0) + return d_r != [] diff --git a/lib/logger.py b/lib/logger.py index fc0b1f40f2..0e64687bbc 100644 --- a/lib/logger.py +++ b/lib/logger.py @@ -44,9 +44,15 @@ def trace(self, msg, *args, **kwargs): class FaceswapFormatter(logging.Formatter): - """ Override formatter to strip newlines and multiple spaces from logger """ + """ Override formatter to strip newlines and multiple spaces from logger + Messages that begin with "R|" should be handled as is + """ def format(self, record): - record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r")) + if record.msg.startswith("R|"): + record.msg = record.msg[2:] + record.strip_spaces = False + elif record.strip_spaces: + record.msg = re.sub(" +", " ", record.msg.replace("\n", "\\n").replace("\r", "\\r")) return super().format(record) @@ -92,7 +98,7 @@ def file_handler(loglevel, logfile, log_format, command): filename = logfile else: filename = os.path.join(os.path.dirname(os.path.realpath(sys.argv[0])), "faceswap") - # Windows has issues sharing the log file with subprocesses, so log GUI seperately + # Windows has issues sharing the log file with subprocesses, so log GUI separately filename += "_gui.log" if command == "gui" else ".log" should_rotate = os.path.isfile(filename) @@ -152,6 +158,18 @@ def crash_log(): return filename +# Add a flag to logging.LogRecord to not strip formatting from particular records +old_factory = logging.getLogRecordFactory() + + +def faceswap_logrecord(*args, **kwargs): + record = old_factory(*args, **kwargs) + record.strip_spaces = True + return record + + +logging.setLogRecordFactory(faceswap_logrecord) + # Set logger class to custom logger logging.setLoggerClass(MultiProcessingLogger) diff --git a/lib/model/__init__.py b/lib/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/model/initializers.py b/lib/model/initializers.py new file mode 100644 index 0000000000..c2b4ce6043 --- /dev/null +++ b/lib/model/initializers.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" Custom Initializers for faceswap.py + Initializers from: + shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + +import sys +import inspect +import tensorflow as tf +from keras import initializers +from keras.utils.generic_utils import get_custom_objects + + +def icnr_keras(shape, dtype=None): + """ + Custom initializer for subpix upscaling + From https://github.com/kostyaev/ICNR + Note: upscale factor is fixed to 2, and the base initializer is fixed to random normal. + """ + # TODO Roll this into ICNR_init when porting GAN 2.2 + shape = list(shape) + scale = 2 + initializer = tf.keras.initializers.RandomNormal(0, 0.02) + + new_shape = shape[:3] + [int(shape[3] / (scale ** 2))] + var_x = initializer(new_shape, dtype) + var_x = tf.transpose(var_x, perm=[2, 0, 1, 3]) + var_x = tf.image.resize_nearest_neighbor(var_x, size=(shape[0] * scale, shape[1] * scale)) + var_x = tf.space_to_depth(var_x, block_size=scale) + var_x = tf.transpose(var_x, perm=[1, 2, 0, 3]) + return var_x + + +class ICNR(initializers.Initializer): # pylint: disable=invalid-name + ''' + ICNR initializer for checkerboard artifact free sub pixel convolution + + Andrew Aitken et al. Checkerboard artifact free sub-pixel convolution + https://arxiv.org/pdf/1707.02937.pdf https://distill.pub/2016/deconv-checkerboard/ + + Parameters: + initializer: initializer used for sub kernels (orthogonal, glorot uniform, etc.) + scale: scale factor of sub pixel convolution (upsampling from 8x8 to 16x16 is scale 2) + Return: + The modified kernel weights + Example: + x = conv2d(... weights_initializer=ICNR(initializer=he_uniform(), scale=2)) + ''' + + def __init__(self, initializer, scale=2): + self.scale = scale + self.initializer = initializer + + def __call__(self, shape, dtype='float32'): # tf needs partition_info=None + shape = list(shape) + if self.scale == 1: + return self.initializer(shape) + new_shape = shape[:3] + [shape[3] // (self.scale ** 2)] + if type(self.initializer) is dict: + self.initializer = initializers.deserialize(self.initializer) + var_x = self.initializer(new_shape, dtype) + var_x = tf.transpose(var_x, perm=[2, 0, 1, 3]) + var_x = tf.image.resize_nearest_neighbor( + var_x, + size=(shape[0] * self.scale, shape[1] * self.scale), + align_corners=True) + var_x = tf.space_to_depth(var_x, block_size=self.scale, data_format='NHWC') + var_x = tf.transpose(var_x, perm=[1, 2, 0, 3]) + return var_x + + def get_config(self): + config = {'scale': self.scale, + 'initializer': self.initializer + } + base_config = super(ICNR, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +# Update initializers into Keras custom objects +for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and obj.__module__ == __name__: + get_custom_objects().update({name: obj}) diff --git a/lib/model/layers.py b/lib/model/layers.py new file mode 100644 index 0000000000..c848b7f23d --- /dev/null +++ b/lib/model/layers.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 +""" Custom Layers for faceswap.py + Layers from: + the original https://www.reddit.com/r/deepfakes/ code sample + contribs + shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + +from __future__ import absolute_import + +import sys +import inspect + +import tensorflow as tf +import keras.backend as K + +from keras.engine import InputSpec, Layer +from keras.utils import conv_utils +from keras.utils.generic_utils import get_custom_objects +from keras import initializers +from keras.layers import ZeroPadding2D + + +class PixelShuffler(Layer): + """ PixelShuffler layer for Keras + by t-ae: https://gist.github.com/t-ae/6e1016cc188104d123676ccef3264981 """ + # pylint: disable=C0103 + def __init__(self, size=(2, 2), data_format=None, **kwargs): + super(PixelShuffler, self).__init__(**kwargs) + self.data_format = K.normalize_data_format(data_format) + self.size = conv_utils.normalize_tuple(size, 2, 'size') + + def call(self, inputs, **kwargs): + + input_shape = K.int_shape(inputs) + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) + + if self.data_format == 'channels_first': + batch_size, c, h, w = input_shape + if batch_size is None: + batch_size = -1 + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = K.reshape(inputs, (batch_size, rh, rw, oc, h, w)) + out = K.permute_dimensions(out, (0, 3, 4, 1, 5, 2)) + out = K.reshape(out, (batch_size, oc, oh, ow)) + elif self.data_format == 'channels_last': + batch_size, h, w, c = input_shape + if batch_size is None: + batch_size = -1 + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) + out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) + out = K.reshape(out, (batch_size, oh, ow, oc)) + return out + + def compute_output_shape(self, input_shape): + + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) + + if self.data_format == 'channels_first': + height = None + width = None + if input_shape[2] is not None: + height = input_shape[2] * self.size[0] + if input_shape[3] is not None: + width = input_shape[3] * self.size[1] + channels = input_shape[1] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[1]: + raise ValueError('channels of input and size are incompatible') + + retval = (input_shape[0], + channels, + height, + width) + elif self.data_format == 'channels_last': + height = None + width = None + if input_shape[1] is not None: + height = input_shape[1] * self.size[0] + if input_shape[2] is not None: + width = input_shape[2] * self.size[1] + channels = input_shape[3] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[3]: + raise ValueError('channels of input and size are incompatible') + + retval = (input_shape[0], + height, + width, + channels) + return retval + + def get_config(self): + config = {'size': self.size, + 'data_format': self.data_format} + base_config = super(PixelShuffler, self).get_config() + + return dict(list(base_config.items()) + list(config.items())) + + +class Scale(Layer): + """ + GAN Custom Scal Layer + Code borrows from https://github.com/flyyufelix/cnn_finetune + """ + def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs): + self.axis = axis + self.gamma_init = initializers.get(gamma_init) + self.initial_weights = weights + super(Scale, self).__init__(**kwargs) + + def build(self, input_shape): + self.input_spec = [InputSpec(shape=input_shape)] + + # Compatibility with TensorFlow >= 1.0.0 + self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name)) + self.trainable_weights = [self.gamma] + + if self.initial_weights is not None: + self.set_weights(self.initial_weights) + del self.initial_weights + + def call(self, x, mask=None): + return self.gamma * x + + def get_config(self): + config = {"axis": self.axis} + base_config = super(Scale, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class SubPixelUpscaling(Layer): + # pylint: disable=C0103 + """ Sub-pixel convolutional upscaling layer based on the paper "Real-Time + Single Image and Video Super-Resolution Using an Efficient Sub-Pixel + Convolutional Neural Network" (https://arxiv.org/abs/1609.05158). + This layer requires a Convolution2D prior to it, having output filters + computed according to the formula : + filters = k * (scale_factor * scale_factor) + where k = a user defined number of filters (generally larger than 32) + scale_factor = the upscaling factor (generally 2) + This layer performs the depth to space operation on the convolution + filters, and returns a tensor with the size as defined below. + # Example : + ```python + # A standard subpixel upscaling block + x = Convolution2D(256, 3, 3, padding="same", activation="relu")(...) + u = SubPixelUpscaling(scale_factor=2)(x) + [Optional] + x = Convolution2D(256, 3, 3, padding="same", activation="relu")(u) + ``` + In practice, it is useful to have a second convolution layer after the + SubPixelUpscaling layer to speed up the learning process. + However, if you are stacking multiple SubPixelUpscaling blocks, + it may increase the number of parameters greatly, so the Convolution + layer after SubPixelUpscaling layer can be removed. + # Arguments + scale_factor: Upscaling factor. + data_format: Can be None, "channels_first" or "channels_last". + # Input shape + 4D tensor with shape: + `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` + if data_format="channels_first" + or 4D tensor with shape: + `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` + if data_format="channels_last". + # Output shape + 4D tensor with shape: + `(samples, k channels, rows * scale_factor, cols * scale_factor))` + if data_format="channels_first" + or 4D tensor with shape: + `(samples, rows * scale_factor, cols * scale_factor, k channels)` + if data_format="channels_last". + """ + + def __init__(self, scale_factor=2, data_format=None, **kwargs): + super(SubPixelUpscaling, self).__init__(**kwargs) + + self.scale_factor = scale_factor + self.data_format = K.normalize_data_format(data_format) + + def build(self, input_shape): + pass + + def call(self, x, mask=None): + y = self.depth_to_space(x, self.scale_factor, self.data_format) + return y + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + b, k, r, c = input_shape + return (b, + k // (self.scale_factor ** 2), + r * self.scale_factor, + c * self.scale_factor) + b, r, c, k = input_shape + return (b, + r * self.scale_factor, + c * self.scale_factor, + k // (self.scale_factor ** 2)) + + @classmethod + def depth_to_space(cls, ipt, scale, data_format=None): + """ Uses phase shift algorithm to convert channels/depth + for spatial resolution """ + if data_format is None: + data_format = K.image_data_format() + data_format = data_format.lower() + ipt = cls._preprocess_conv2d_input(ipt, data_format) + out = tf.depth_to_space(ipt, scale) + out = cls._postprocess_conv2d_output(out, data_format) + return out + + @staticmethod + def _postprocess_conv2d_output(x, data_format): + """Transpose and cast the output from conv2d if needed. + # Arguments + x: A tensor. + data_format: string, `"channels_last"` or `"channels_first"`. + # Returns + A tensor. + """ + + if data_format == "channels_first": + x = tf.transpose(x, (0, 3, 1, 2)) + + if K.floatx() == "float64": + x = tf.cast(x, "float64") + return x + + @staticmethod + def _preprocess_conv2d_input(x, data_format): + """Transpose and cast the input before the conv2d. + # Arguments + x: input tensor. + data_format: string, `"channels_last"` or `"channels_first"`. + # Returns + A tensor. + """ + if K.dtype(x) == "float64": + x = tf.cast(x, "float32") + if data_format == "channels_first": + # TF uses the last dimension as channel dimension, + # instead of the 2nd one. + # TH input shape: (samples, input_depth, rows, cols) + # TF input shape: (samples, rows, cols, input_depth) + x = tf.transpose(x, (0, 2, 3, 1)) + return x + + def get_config(self): + config = {"scale_factor": self.scale_factor, + "data_format": self.data_format} + base_config = super(SubPixelUpscaling, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class ReflectionPadding2D(Layer): + def __init__(self, stride=2, kernel_size=5, **kwargs): + ''' + # Arguments + stride: stride of following convolution (2) + kernel_size: kernel size of following convolution (5,5) + ''' + self.stride = stride + self.kernel_size = kernel_size + super(ReflectionPadding2D, self).__init__(**kwargs) + + def build(self, input_shape): + self.input_spec = [InputSpec(shape=input_shape)] + super(ReflectionPadding2D, self).build(input_shape) + + def compute_output_shape(self, input_shape): + """ If you are using "channels_last" configuration""" + input_shape = self.input_spec[0].shape + in_width, in_height = input_shape[2], input_shape[1] + kernel_width, kernel_height = self.kernel_size, self.kernel_size + + if (in_height % self.stride == 0): + padding_height = max(kernel_height - self.stride, 0) + else: + padding_height = max(kernel_height - (in_height % self.stride), 0) + if (in_width % self.stride == 0): + padding_width = max(kernel_width - self.stride, 0) + else: + padding_width = max(kernel_width- (in_width % self.stride), 0) + + return (input_shape[0], + input_shape[1] + padding_height, + input_shape[2] + padding_width, + input_shape[3]) + + def call(self, x, mask=None): + input_shape = self.input_spec[0].shape + in_width, in_height = input_shape[2], input_shape[1] + kernel_width, kernel_height = self.kernel_size, self.kernel_size + + if (in_height % self.stride == 0): + padding_height = max(kernel_height - self.stride, 0) + else: + padding_height = max(kernel_height - (in_height % self.stride), 0) + if (in_width % self.stride == 0): + padding_width = max(kernel_width - self.stride, 0) + else: + padding_width = max(kernel_width- (in_width % self.stride), 0) + + padding_top = padding_height // 2 + padding_bot = padding_height - padding_top + padding_left = padding_width // 2 + padding_right = padding_width - padding_left + + return tf.pad(x, [[0,0], + [padding_top, padding_bot], + [padding_left, padding_right], + [0,0] ], + 'REFLECT') + + def get_config(self): + config = {'stride': self.stride, + 'kernel_size': self.kernel_size} + base_config = super(ReflectionPadding2D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +# Update layers into Keras custom objects +for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and obj.__module__ == __name__: + get_custom_objects().update({name: obj}) diff --git a/lib/model/losses.py b/lib/model/losses.py new file mode 100644 index 0000000000..b54e2bc1d6 --- /dev/null +++ b/lib/model/losses.py @@ -0,0 +1,844 @@ +#!/usr/bin/env python3 +""" Custom Loss Functions for faceswap.py + Losses from: + keras.contrib + dfaker: https://github.com/dfaker/df + shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + +from __future__ import absolute_import + + +import keras.backend as K +from keras.layers import Lambda, concatenate +import tensorflow as tf +from tensorflow.contrib.distributions import Beta + +from .normalization import InstanceNormalization + + +class DSSIMObjective(): + """ DSSIM Loss Function + + Code copy and pasted, with minor ammendments from: + https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py + + MIT License + + Copyright (c) 2017 Fariz Rahman + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. """ + # pylint: disable=C0103 + def __init__(self, k1=0.01, k2=0.03, kernel_size=3, max_value=1.0): + """ + Difference of Structural Similarity (DSSIM loss function). Clipped + between 0 and 0.5 + Note : You should add a regularization term like a l2 loss in + addition to this one. + Note : In theano, the `kernel_size` must be a factor of the output + size. So 3 could not be the `kernel_size` for an output of 32. + # Arguments + k1: Parameter of the SSIM (default 0.01) + k2: Parameter of the SSIM (default 0.03) + kernel_size: Size of the sliding window (default 3) + max_value: Max value of the output (default 1.0) + """ + self.__name__ = 'DSSIMObjective' + self.kernel_size = kernel_size + self.k1 = k1 + self.k2 = k2 + self.max_value = max_value + self.c1 = (self.k1 * self.max_value) ** 2 + self.c2 = (self.k2 * self.max_value) ** 2 + self.dim_ordering = K.image_data_format() + self.backend = K.backend() + + @staticmethod + def __int_shape(x): + return K.int_shape(x) + + def __call__(self, y_true, y_pred): + # There are additional parameters for this function + # Note: some of the 'modes' for edge behavior do not yet have a + # gradient definition in the Theano tree and cannot be used for + # learning + + kernel = [self.kernel_size, self.kernel_size] + y_true = K.reshape(y_true, [-1] + list(self.__int_shape(y_pred)[1:])) + y_pred = K.reshape(y_pred, [-1] + list(self.__int_shape(y_pred)[1:])) + + patches_pred = self.extract_image_patches(y_pred, + kernel, + kernel, + 'valid', + self.dim_ordering) + patches_true = self.extract_image_patches(y_true, + kernel, + kernel, + 'valid', + self.dim_ordering) + + # Reshape to get the var in the cells + _, w, h, c1, c2, c3 = self.__int_shape(patches_pred) + patches_pred = K.reshape(patches_pred, [-1, w, h, c1 * c2 * c3]) + patches_true = K.reshape(patches_true, [-1, w, h, c1 * c2 * c3]) + # Get mean + u_true = K.mean(patches_true, axis=-1) + u_pred = K.mean(patches_pred, axis=-1) + # Get variance + var_true = K.var(patches_true, axis=-1) + var_pred = K.var(patches_pred, axis=-1) + # Get std dev + covar_true_pred = K.mean( + patches_true * patches_pred, axis=-1) - u_true * u_pred + + ssim = (2 * u_true * u_pred + self.c1) * ( + 2 * covar_true_pred + self.c2) + denom = (K.square(u_true) + K.square(u_pred) + self.c1) * ( + var_pred + var_true + self.c2) + ssim /= denom # no need for clipping, c1 + c2 make the denom non-zero + return K.mean((1.0 - ssim) / 2.0) + + @staticmethod + def _preprocess_padding(padding): + """Convert keras' padding to tensorflow's padding. + # Arguments + padding: string, `"same"` or `"valid"`. + # Returns + a string, `"SAME"` or `"VALID"`. + # Raises + ValueError: if `padding` is invalid. + """ + if padding == 'same': + padding = 'SAME' + elif padding == 'valid': + padding = 'VALID' + else: + raise ValueError('Invalid padding:', padding) + return padding + + def extract_image_patches(self, x, ksizes, ssizes, padding='same', + data_format='channels_last'): + ''' + Extract the patches from an image + # Parameters + x : The input image + ksizes : 2-d tuple with the kernel size + ssizes : 2-d tuple with the strides size + padding : 'same' or 'valid' + data_format : 'channels_last' or 'channels_first' + # Returns + The (k_w,k_h) patches extracted + TF ==> (batch_size,w,h,k_w,k_h,c) + TH ==> (batch_size,w,h,c,k_w,k_h) + ''' + kernel = [1, ksizes[0], ksizes[1], 1] + strides = [1, ssizes[0], ssizes[1], 1] + padding = self._preprocess_padding(padding) + if data_format == 'channels_first': + x = K.permute_dimensions(x, (0, 2, 3, 1)) + _, _, _, ch_i = K.int_shape(x) + patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1], + padding) + # Reshaping to fit Theano + _, w, h, ch = K.int_shape(patches) + patches = tf.reshape(tf.transpose(tf.reshape(patches, + [-1, w, h, + tf.floordiv(ch, ch_i), + ch_i]), + [0, 1, 2, 4, 3]), + [-1, w, h, ch_i, ksizes[0], ksizes[1]]) + if data_format == 'channels_last': + patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3]) + return patches + +# <<< START: from Dfaker >>> # +class PenalizedLoss(): # pylint: disable=too-few-public-methods + """ Penalized Loss + from: https://github.com/dfaker/df """ + def __init__(self, mask, loss_func, mask_prop=1.0): + self.mask = mask + self.loss_func = loss_func + self.mask_prop = mask_prop + self.mask_as_k_inv_prop = 1-mask_prop + + def __call__(self, y_true, y_pred): + # pylint: disable=invalid-name + tro, tgo, tbo = tf.split(y_true, 3, 3) + pro, pgo, pbo = tf.split(y_pred, 3, 3) + + tr = tro + tg = tgo + tb = tbo + + pr = pro + pg = pgo + pb = pbo + m = self.mask + + m = m * self.mask_prop + m += self.mask_as_k_inv_prop + tr *= m + tg *= m + tb *= m + + pr *= m + pg *= m + pb *= m + + y = tf.concat([tr, tg, tb], 3) + p = tf.concat([pr, pg, pb], 3) + + # yo = tf.stack([tro,tgo,tbo],3) + # po = tf.stack([pro,pgo,pbo],3) + + return self.loss_func(y, p) +# <<< END: from Dfaker >>> # + + +# <<< START: from Shoanlu GAN >>> # +def first_order(var_x, axis=1): + """ First Order Function from Shoanlu GAN """ + img_nrows = var_x.shape[1] + img_ncols = var_x.shape[2] + if axis == 1: + return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, 1:, :img_ncols - 1, :]) + if axis == 2: + return K.abs(var_x[:, :img_nrows - 1, :img_ncols - 1, :] - var_x[:, :img_nrows - 1, 1:, :]) + return None + + +def calc_loss(pred, target, loss='l2'): + """ Calculate Loss from Shoanlu GAN """ + if loss.lower() == "l2": + return K.mean(K.square(pred - target)) + if loss.lower() == "l1": + return K.mean(K.abs(pred - target)) + if loss.lower() == "cross_entropy": + return -K.mean(K.log(pred + K.epsilon()) * target + + K.log(1 - pred + K.epsilon()) * (1 - target)) + raise ValueError('Recieve an unknown loss type: {}.'.format(loss)) + + +def cyclic_loss(net_g1, net_g2, real1): + """ Cyclic Loss Function from Shoanlu GAN """ + fake2 = net_g2(real1)[-1] # fake2 ABGR + fake2 = Lambda(lambda x: x[:, :, :, 1:])(fake2) # fake2 BGR + cyclic1 = net_g1(fake2)[-1] # cyclic1 ABGR + cyclic1 = Lambda(lambda x: x[:, :, :, 1:])(cyclic1) # cyclic1 BGR + loss = calc_loss(cyclic1, real1, loss='l1') + return loss + + +def adversarial_loss(net_d, real, fake_abgr, distorted, gan_training="mixup_LSGAN", **weights): + """ Adversarial Loss Function from Shoanlu GAN """ + alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) + fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) + fake = alpha * fake_bgr + (1-alpha) * distorted + + if gan_training == "mixup_LSGAN": + dist = Beta(0.2, 0.2) + lam = dist.sample() + mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted]) + pred_fake = net_d(concatenate([fake, distorted])) + pred_mixup = net_d(mixup) + loss_d = calc_loss(pred_mixup, lam * K.ones_like(pred_mixup), "l2") + loss_g = weights['w_D'] * calc_loss(pred_fake, K.ones_like(pred_fake), "l2") + mixup2 = lam * concatenate([real, + distorted]) + (1 - lam) * concatenate([fake_bgr, + distorted]) + pred_fake_bgr = net_d(concatenate([fake_bgr, distorted])) + pred_mixup2 = net_d(mixup2) + loss_d += calc_loss(pred_mixup2, lam * K.ones_like(pred_mixup2), "l2") + loss_g += weights['w_D'] * calc_loss(pred_fake_bgr, K.ones_like(pred_fake_bgr), "l2") + elif gan_training == "relativistic_avg_LSGAN": + real_pred = net_d(concatenate([real, distorted])) + fake_pred = net_d(concatenate([fake, distorted])) + loss_d = K.mean(K.square(real_pred - K.ones_like(fake_pred)))/2 + loss_d += K.mean(K.square(fake_pred - K.zeros_like(fake_pred)))/2 + loss_g = weights['w_D'] * K.mean(K.square(fake_pred - K.ones_like(fake_pred))) + + fake_pred2 = net_d(concatenate([fake_bgr, distorted])) + loss_d += K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) - + K.ones_like(fake_pred2)))/2 + loss_d += K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) - + K.zeros_like(fake_pred2)))/2 + loss_g += weights['w_D'] * K.mean(K.square(real_pred - K.mean(fake_pred2, axis=0) - + K.zeros_like(fake_pred2)))/2 + loss_g += weights['w_D'] * K.mean(K.square(fake_pred2 - K.mean(real_pred, axis=0) - + K.ones_like(fake_pred2)))/2 + else: + raise ValueError("Receive an unknown GAN training method: {gan_training}") + return loss_d, loss_g + + +def reconstruction_loss(real, fake_abgr, mask_eyes, model_outputs, **weights): + """ Reconstruction Loss Function from Shoanlu GAN """ + alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) + fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) + + loss_g = 0 + loss_g += weights['w_recon'] * calc_loss(fake_bgr, real, "l1") + loss_g += weights['w_eyes'] * K.mean(K.abs(mask_eyes*(fake_bgr - real))) + + for out in model_outputs[:-1]: + out_size = out.get_shape().as_list() + resized_real = tf.image.resize_images(real, out_size[1:3]) + loss_g += weights['w_recon'] * calc_loss(out, resized_real, "l1") + return loss_g + + +def edge_loss(real, fake_abgr, mask_eyes, **weights): + """ Edge Loss Function from Shoanlu GAN """ + alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) + fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) + + loss_g = 0 + loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=1), + first_order(real, axis=1), "l1") + loss_g += weights['w_edge'] * calc_loss(first_order(fake_bgr, axis=2), + first_order(real, axis=2), "l1") + shape_mask_eyes = mask_eyes.get_shape().as_list() + resized_mask_eyes = tf.image.resize_images(mask_eyes, + [shape_mask_eyes[1]-1, shape_mask_eyes[2]-1]) + loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes * + (first_order(fake_bgr, axis=1) - + first_order(real, axis=1)))) + loss_g += weights['w_eyes'] * K.mean(K.abs(resized_mask_eyes * + (first_order(fake_bgr, axis=2) - + first_order(real, axis=2)))) + return loss_g + + +def perceptual_loss(real, fake_abgr, distorted, mask_eyes, vggface_feats, **weights): + """ Perceptual Loss Function from Shoanlu GAN """ + alpha = Lambda(lambda x: x[:, :, :, :1])(fake_abgr) + fake_bgr = Lambda(lambda x: x[:, :, :, 1:])(fake_abgr) + fake = alpha * fake_bgr + (1-alpha) * distorted + + def preprocess_vggface(var_x): + var_x = (var_x + 1) / 2 * 255 # channel order: BGR + var_x -= [91.4953, 103.8827, 131.0912] + return var_x + + real_sz224 = tf.image.resize_images(real, [224, 224]) + real_sz224 = Lambda(preprocess_vggface)(real_sz224) + dist = Beta(0.2, 0.2) + lam = dist.sample() # use mixup trick here to reduce foward pass from 2 times to 1. + mixup = lam*fake_bgr + (1-lam)*fake + fake_sz224 = tf.image.resize_images(mixup, [224, 224]) + fake_sz224 = Lambda(preprocess_vggface)(fake_sz224) + real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feats(real_sz224) + fake_feat112, fake_feat55, fake_feat28, fake_feat7 = vggface_feats(fake_sz224) + + # Apply instance norm on VGG(ResNet) features + # From MUNIT https://github.com/NVlabs/MUNIT + loss_g = 0 + + def instnorm(): + return InstanceNormalization() + + loss_g += weights['w_pl'][0] * calc_loss(instnorm()(fake_feat7), + instnorm()(real_feat7), "l2") + loss_g += weights['w_pl'][1] * calc_loss(instnorm()(fake_feat28), + instnorm()(real_feat28), "l2") + loss_g += weights['w_pl'][2] * calc_loss(instnorm()(fake_feat55), + instnorm()(real_feat55), "l2") + loss_g += weights['w_pl'][3] * calc_loss(instnorm()(fake_feat112), + instnorm()(real_feat112), "l2") + return loss_g + +# <<< END: from Shoanlu GAN >>> # + + +def generalized_loss_function(y_true, y_pred, a = 1.0, c=1.0/255.0): + ''' + generalized function used to return a large variety of mathematical loss functions + primary benefit is smooth, differentiable version of L1 loss + + Barron, J. A More General Robust Loss Function + https://arxiv.org/pdf/1701.03077.pdf + + Parameters: + a: penalty factor. larger number give larger weight to large deviations + c: scale factor used to adjust to the input scale (i.e. inputs of mean 1e-4 or 256 ) + + Return: + a loss value from the results of function(y_pred - y_true) + + Example: + a=1.0, x>>c , c=1.0/255.0 will give a smoothly differentiable version of L1 / MAE loss + a=1.999999 (lim as a->2), c=1.0/255.0 will give L2 / RMSE loss + ''' + x = y_pred - y_true + loss = (K.abs(2.0-a)/a) * ( K.pow( K.pow(x/c, 2.0)/K.abs(2.0-a) + 1.0 , (a/2.0)) - 1.0 ) + return K.mean(loss, axis=-1) * c + + +def staircase_loss(y_true, y_pred, a = 16.0, c=1.0/255.0): + h = c + w = c + x = K.clip(K.abs(y_true - y_pred) - 0.5 * c, 0.0, 1.0) + loss = h*( K.tanh(a*((x/w)-tf.floor(x/w)-0.5)) / ( 2.0*K.tanh(a/2.0) ) + 0.5 + tf.floor(x/w)) + loss += 1e-10 + return K.mean(loss, axis=-1) + + +def gradient_loss(y_true, y_pred): + ''' + Calculates the first and second order gradient difference between pixels of an image in the x and y dimensions. + These gradients are then compared between the ground truth and the predicted image and the difference is taken. + The difference used is a smooth L1 norm ( approximate to MAE but differable at zero ) + When used as a loss, its minimization will result in predicted images approaching the same level of sharpness + / blurriness as the ground truth. + + TV+TV2 Regularization with Nonconvex Sparseness-Inducing Penalty for Image Restoration, Chengwu Lu & Hua Huang, 2014 + (http://downloads.hindawi.com/journals/mpe/2014/790547.pdf) + + Parameters: + y_true: The predicted frames at each scale. + y_true: The ground truth frames at each scale + + Return: + The GD loss. + ''' + + assert 4 == K.ndim(y_true) + y_true.set_shape([None,80,80,3]) + y_pred.set_shape([None,80,80,3]) + TV_weight = 1.0 + TV2_weight = 1.0 + loss = 0.0 + + def diff_x(X): + Xleft = X[:, :, 1, :] - X[:, :, 0, :] + Xinner = tf.unstack(X[:, :, 2:, :] - X[:, :, :-2, :], axis=2) + Xright = X[:, :, -1, :] - X[:, :, -2, :] + Xout = [Xleft] + Xinner + [Xright] + Xout = tf.stack(Xout,axis=2) + return Xout * 0.5 + + def diff_y(X): + Xtop = X[:, 1, :, :] - X[:, 0, :, :] + Xinner = tf.unstack(X[:, 2:, :, :] - X[:, :-2, :, :], axis=1) + Xbot = X[:, -1, :, :] - X[:, -2, :, :] + Xout = [Xtop] + Xinner + [Xbot] + Xout = tf.stack(Xout,axis=1) + return Xout * 0.5 + + def diff_xx(X): + Xleft = X[:, :, 1, :] + X[:, :, 0, :] + Xinner = tf.unstack(X[:, :, 2:, :] + X[:, :, :-2, :], axis=2) + Xright = X[:, :, -1, :] + X[:, :, -2, :] + Xout = [Xleft] + Xinner + [Xright] + Xout = tf.stack(Xout,axis=2) + return Xout - 2.0 * X + + def diff_yy(X): + Xtop = X[:, 1, :, :] + X[:, 0, :, :] + Xinner = tf.unstack(X[:, 2:, :, :] + X[:, :-2, :, :], axis=1) + Xbot = X[:, -1, :, :] + X[:, -2, :, :] + Xout = [Xtop] + Xinner + [Xbot] + Xout = tf.stack(Xout,axis=1) + return Xout - 2.0 * X + + def diff_xy(X): + #xout1 + top_left = X[:, 1, 1, :]+X[:, 0, 0, :] + inner_left = tf.unstack(X[:, 2:, 1, :]+X[:, :-2, 0, :], axis=1) + bot_left = X[:, -1, 1, :]+X[:, -2, 0, :] + X_left = [top_left] + inner_left + [bot_left] + X_left = tf.stack(X_left, axis=1) + + top_mid = X[:, 1, 2:, :]+X[:, 0, :-2, :] + mid_mid = tf.unstack(X[:, 2:, 2:, :]+X[:, :-2, :-2, :], axis=1) + bot_mid = X[:, -1, 2:, :]+X[:, -2, :-2, :] + X_mid = [top_mid] + mid_mid + [bot_mid] + X_mid = tf.stack(X_mid, axis=1) + + top_right = X[:, 1, -1, :]+X[:, 0, -2, :] + inner_right = tf.unstack(X[:, 2:, -1, :]+X[:, :-2, -2, :], axis=1) + bot_right = X[:, -1, -1, :]+X[:, -2, -2, :] + X_right = [top_right] + inner_right + [bot_right] + X_right = tf.stack(X_right, axis=1) + + X_mid = tf.unstack(X_mid, axis=2) + Xout1 = [X_left] + X_mid + [X_right] + Xout1 = tf.stack(Xout1, axis=2) + + #Xout2 + top_left = X[:, 0, 1, :]+X[:, 1, 0, :] + inner_left = tf.unstack(X[:, :-2, 1, :]+X[:, 2:, 0, :], axis=1) + bot_left = X[:, -2, 1, :]+X[:, -1, 0, :] + X_left = [top_left] + inner_left + [bot_left] + X_left = tf.stack(X_left, axis=1) + + top_mid = X[:, 0, 2:, :]+X[:, 1, :-2, :] + mid_mid = tf.unstack(X[:, :-2, 2:, :]+X[:, 2:, :-2, :], axis=1) + bot_mid = X[:, -2, 2:, :]+X[:, -1, :-2, :] + X_mid = [top_mid] + mid_mid + [bot_mid] + X_mid = tf.stack(X_mid, axis=1) + + top_right = X[:, 0, -1, :]+X[:, 1, -2, :] + inner_right = tf.unstack(X[:, :-2, -1, :]+X[:, 2:, -2, :], axis=1) + bot_right = X[:, -2, -1, :]+X[:, -1, -2, :] + X_right = [top_right] + inner_right + [bot_right] + X_right = tf.stack(X_right, axis=1) + + X_mid = tf.unstack(X_mid, axis=2) + Xout2 = [X_left] + X_mid + [X_right] + Xout2 = tf.stack(Xout2, axis=2) + + return (Xout1 - Xout2) * 0.25 + + loss += TV_weight * ( generalized_loss_function(diff_x(y_true), diff_x(y_pred), a=1.999999) + + generalized_loss_function(diff_y(y_true), diff_y(y_pred), a=1.999999) ) + + loss += TV2_weight * ( generalized_loss_function(diff_xx(y_true), diff_xx(y_pred), a=1.999999) + + generalized_loss_function(diff_yy(y_true), diff_yy(y_pred), a=1.999999) + + 2.0 * generalized_loss_function(diff_xy(y_true), diff_xy(y_pred), a=1.999999) ) + + return loss / ( TV_weight + TV2_weight ) + + +def scharr_edges(image, magnitude): + ''' + Returns a tensor holding modified Scharr edge maps. + Arguments: + image: Image tensor with shape [batch_size, h, w, d] and type float32. + The image(s) must be 2x2 or larger. + magnitude: Boolean to determine if the edge magnitude or edge direction is returned + Returns: + Tensor holding edge maps for each channel. Returns a tensor with shape + [batch_size, h, w, d, 2] where the last two dimensions hold [[dy[0], dx[0]], + [dy[1], dx[1]], ..., [dy[d-1], dx[d-1]]] calculated using the Scharr filter. + ''' + + # Define vertical and horizontal Scharr filters. + static_image_shape = image.get_shape() + image_shape = tf.shape(image) + ''' + #modified 3x3 Scharr + kernels = [[[-17.0, -61.0, -17.0], [0.0, 0.0, 0.0], [17.0, 61.0, 17.0]], + [[-17.0, 0.0, 17.0], [-61.0, 0.0, 61.0], [-17.0, 0.0, 17.0]]] + ''' + # 5x5 Scharr + kernels = [[[-1.0, -2.0, -3.0, -2.0, -1.0], [-1.0, -2.0, -6.0, -2.0, -1.0], [0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 6.0, 2.0, 1.0], [1.0, 2.0, 3.0, 2.0, 1.0]], + [[-1.0, -1.0, 0.0, 1.0, 1.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-3.0, -6.0, 0.0, 6.0, 3.0], [-2.0, -2.0, 0.0, 2.0, 2.0], [-1.0, -1.0, 0.0, 1.0, 1.0]]] + num_kernels = len(kernels) + kernels = numpy.transpose(numpy.asarray(kernels), (1, 2, 0)) + kernels = numpy.expand_dims(kernels, -2) / numpy.sum(numpy.abs(kernels)) + kernels_tf = tf.constant(kernels, dtype=image.dtype) + kernels_tf = tf.tile(kernels_tf, [1, 1, image_shape[-1], 1], name='scharr_filters') + + # Use depth-wise convolution to calculate edge maps per channel. + pad_sizes = [[0, 0], [2, 2], [2, 2], [0, 0]] + padded = tf.pad(image, pad_sizes, mode='REFLECT') + + # Output tensor has shape [batch_size, h, w, d * num_kernels]. + strides = [1, 1, 1, 1] + output = tf.nn.depthwise_conv2d(padded, kernels_tf, strides, 'VALID') + + # Reshape to [batch_size, h, w, d, num_kernels]. + shape = tf.concat([image_shape, [num_kernels]], 0) + output = tf.reshape(output, shape=shape) + output.set_shape(static_image_shape.concatenate([num_kernels])) + + if magnitude: # magnitude of edges + output = tf.sqrt(tf.reduce_sum(tf.square(output),axis=-1)) + else: # direction of edges + output = tf.atan(tf.squeeze(tf.div(output[:,:,:,:,0]/output[:,:,:,:,1]))) + + return output + + +def gmsd_loss(y_true,y_pred): + ''' + Improved image quality metric over MS-SSIM with easier calc + http://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.htm + https://arxiv.org/ftp/arxiv/papers/1308/1308.3052.pdf + ''' + true_edge_mag = scharr_edges(y_true,True) + pred_edge_mag = scharr_edges(y_pred,True) + c = 0.002 + upper = 2.0 * tf.multiply(true_edge_mag,pred_edge_mag) + c + lower = tf.square(true_edge_mag) + tf.square(pred_edge_mag) + c + GMS = tf.div(upper,lower) + _mean, _var = tf.nn.moments(GMS, axes=[1,2], keep_dims=True) + GMSD = tf.reduce_mean(tf.sqrt(_var), axis=-1) # single metric value per image in tensor [?,1,1] + return K.tile(GMSD,[1,64,64]) # need to expand to [?,height,width] dimensions for Keras ... modify to not be hard-coded + + +def ms_ssim(img1, img2, max_val=1.0, power_factors=(0.0517, 0.3295, 0.3462, 0.2726)): + ''' + Computes the MS-SSIM between img1 and img2. + This function assumes that `img1` and `img2` are image batches, i.e. the last + three dimensions are [height, width, channels]. + Note: The true SSIM is only defined on grayscale. This function does not + perform any colorspace transform. (If input is already YUV, then it will + compute YUV SSIM average.) + Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale + structural similarity for image quality assessment." Signals, Systems and + Computers, 2004. + Arguments: + img1: First image batch. + img2: Second image batch. Must have the same rank as img1. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + power_factors: Iterable of weights for each of the scales. The number of + scales used is the length of the list. Index 0 is the unscaled + resolution's weight and each increasing scale corresponds to the image + being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363, + 0.1333), which are the values obtained in the original paper. + Returns: + A tensor containing an MS-SSIM value for each image in batch. The values + are in range [0, 1]. Returns a tensor with shape: + broadcast(img1.shape[:-3], img2.shape[:-3]). + ''' + + def _verify_compatible_image_shapes(img1, img2): + ''' + Checks if two image tensors are compatible for applying SSIM or PSNR. + This function checks if two sets of images have ranks at least 3, and if the + last three dimensions match. + Args: + img1: Tensor containing the first image batch. + img2: Tensor containing the second image batch. + Returns: + A tuple containing: the first tensor shape, the second tensor shape, and a + list of control_flow_ops.Assert() ops implementing the checks. + Raises: + ValueError: When static shape check fails. + ''' + shape1 = img1.get_shape().with_rank_at_least(3) + shape2 = img2.get_shape().with_rank_at_least(3) + shape1[-3:].assert_is_compatible_with(shape2[-3:]) + + if shape1.ndims is not None and shape2.ndims is not None: + for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])): + if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)): + raise ValueError('Two images are not compatible: %s and %s' % (shape1, shape2)) + + # Now assign shape tensors. + shape1, shape2 = tf.shape_n([img1, img2]) + + # TODO(sjhwang): Check if shape1[:-3] and shape2[:-3] are broadcastable. + checks = [] + checks.append(tf.Assert(tf.greater_equal(tf.size(shape1), 3),[shape1, shape2], summarize=10)) + checks.append(tf.Assert(tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),[shape1, shape2], summarize=10)) + + return shape1, shape2, checks + + def _ssim_per_channel(img1, img2, max_val=1.0): + ''' + Computes SSIM index between img1 and img2 per color channel. + This function matches the standard SSIM implementation from: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image + quality assessment: from error visibility to structural similarity. IEEE + transactions on image processing. + Details: + - 11x11 Gaussian filter of width 1.5 is used. + - k1 = 0.01, k2 = 0.03 as in the original paper. + Args: + img1: First image batch. + img2: Second image batch. + max_val: The dynamic range of the images (i.e., the difference between the + maximum the and minimum allowed values). + Returns: + A pair of tensors containing and channel-wise SSIM and contrast-structure + values. The shape is [..., channels]. + ''' + + def _fspecial_gauss(size, sigma): + ''' + Function to mimic the 'fspecial' gaussian MATLAB function. + ''' + size = tf.convert_to_tensor(size, 'int32') + sigma = tf.convert_to_tensor(sigma) + + coords = tf.cast(tf.range(size), sigma.dtype) + coords -= tf.cast(size - 1, sigma.dtype) / 2.0 + + g = tf.square(coords) + g *= -0.5 / tf.square(sigma) + + g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1]) + g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax(). + g = tf.nn.softmax(g) + return tf.reshape(g, shape=[size, size, 1, 1]) + + def _ssim_helper(x, y, max_val, kernel, compensation=1.0): + ''' + Helper function for computing SSIM. + SSIM estimates covariances with weighted sums. The default parameters + use a biased estimate of the covariance: + Suppose `reducer` is a weighted sum, then the mean estimators are + \mu_x = \sum_i w_i x_i, + \mu_y = \sum_i w_i y_i, + where w_i's are the weighted-sum weights, and covariance estimator is + cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) + with assumption \sum_i w_i = 1. This covariance estimator is biased, since + E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y). + For SSIM measure with unbiased covariance estimators, pass as `compensation` + argument (1 - \sum_i w_i ^ 2). + Arguments: + x: First set of images. + y: Second set of images. + reducer: Function that computes 'local' averages from set of images. + For non-covolutional version, this is usually tf.reduce_mean(x, [1, 2]), + and for convolutional version, this is usually tf.nn.avg_pool or + tf.nn.conv2d with weighted-sum kernel. + max_val: The dynamic range (i.e., the difference between the maximum + possible allowed value and the minimum allowed value). + compensation: Compensation factor. See above. + Returns: + A pair containing the luminance measure, and the contrast-structure measure. + ''' + + def reducer(x, kernel): + shape = tf.shape(x) + x = tf.reshape(x, shape=tf.concat([[-1], shape[-3:]], 0)) + y = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID') + return tf.reshape(y, tf.concat([shape[:-3],tf.shape(y)[1:]], 0)) + + _SSIM_K1 = 0.01 + _SSIM_K2 = 0.03 + + c1 = (_SSIM_K1 * max_val) ** 2 + c2 = (_SSIM_K2 * max_val) ** 2 + + # SSIM luminance measure is + # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1). + mean0 = reducer(x, kernel) + mean1 = reducer(y, kernel) + num0 = mean0 * mean1 * 2.0 + den0 = tf.square(mean0) + tf.square(mean1) + luminance = (num0 + c1) / (den0 + c1) + + # SSIM contrast-structure measure is + # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2). + # Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then + # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y) + # = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j). + num1 = reducer(x * y, kernel) * 2.0 + den1 = reducer(tf.square(x) + tf.square(y), kernel) + c2 *= compensation + cs = (num1 - num0 + c2) / (den1 - den0 + c2) + + # SSIM score is the product of the luminance and contrast-structure measures. + return luminance, cs + + filter_size = tf.constant(9, dtype='int32') # changed from 11 to 9 due + filter_sigma = tf.constant(1.5, dtype=img1.dtype) + + shape1, shape2 = tf.shape_n([img1, img2]) + checks = [tf.Assert(tf.reduce_all(tf.greater_equal(shape1[-3:-1], filter_size)),[shape1, filter_size], summarize=8), + tf.Assert(tf.reduce_all(tf.greater_equal(shape2[-3:-1], filter_size)),[shape2, filter_size], summarize=8)] + + # Enforce the check to run before computation. + with tf.control_dependencies(checks): + img1 = tf.identity(img1) + + # TODO(sjhwang): Try to cache kernels and compensation factor. + kernel = _fspecial_gauss(filter_size, filter_sigma) + kernel = tf.tile(kernel, multiples=[1, 1, shape1[-1], 1]) + + # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`, + # but to match MATLAB implementation of MS-SSIM, we use 1.0 instead. + compensation = 1.0 + + # TODO(sjhwang): Try FFT. + # TODO(sjhwang): Gaussian kernel is separable in space. Consider applying + # 1-by-n and n-by-1 Gaussain filters instead of an n-by-n filter. + + luminance, cs = _ssim_helper(img1, img2, max_val, kernel, compensation) + + # Average over the second and the third from the last: height, width. + axes = tf.constant([-3, -2], dtype='int32') + ssim_val = tf.reduce_mean(luminance * cs, axes) + cs = tf.reduce_mean(cs, axes) + return ssim_val, cs + + def do_pad(images, remainder): + padding = tf.expand_dims(remainder, -1) + padding = tf.pad(padding, [[1, 0], [1, 0]]) + return [tf.pad(x, padding, mode='SYMMETRIC') for x in images] + + # Shape checking. + shape1 = img1.get_shape().with_rank_at_least(3) + shape2 = img2.get_shape().with_rank_at_least(3) + shape1[-3:].merge_with(shape2[-3:]) + + with tf.name_scope(None, 'MS-SSIM', [img1, img2]): + shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) + with tf.control_dependencies(checks): + img1 = tf.identity(img1) + + # Need to convert the images to float32. Scale max_val accordingly so that + # SSIM is computed correctly. + max_val = tf.cast(max_val, img1.dtype) + max_val = tf.image.convert_image_dtype(max_val, 'float32') + img1 = tf.image.convert_image_dtype(img1, 'float32') + img2 = tf.image.convert_image_dtype(img2, 'float32') + + imgs = [img1, img2] + shapes = [shape1, shape2] + + # img1 and img2 are assumed to be a (multi-dimensional) batch of + # 3-dimensional images (height, width, channels). `heads` contain the batch + # dimensions, and `tails` contain the image dimensions. + heads = [s[:-3] for s in shapes] + tails = [s[-3:] for s in shapes] + + divisor = [1, 2, 2, 1] + divisor_tensor = tf.constant(divisor[1:], dtype='int32') + + mcs = [] + for k in range(len(power_factors)): + with tf.name_scope(None, 'Scale%d' % k, imgs): + if k > 0: + # Avg pool takes rank 4 tensors. Flatten leading dimensions. + flat_imgs = [tf.reshape(x, tf.concat([[-1], t], 0)) for x, t in zip(imgs, tails)] + + remainder = tails[0] % divisor_tensor + need_padding = tf.reduce_any(tf.not_equal(remainder, 0)) + padded = tf.cond(need_padding,lambda: do_pad(flat_imgs, remainder), + lambda: flat_imgs) + + downscaled = [tf.nn.avg_pool(x, ksize=divisor, strides=divisor, padding='VALID') + for x in padded] + tails = [x[1:] for x in tf.shape_n(downscaled)] + imgs = [tf.reshape(x, tf.concat([h, t], 0)) for x, h, t in zip(downscaled, heads, tails)] + + # Overwrite previous ssim value since we only need the last one. + ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val) + mcs.append(tf.nn.relu(cs)) + + # Remove the cs score for the last scale. In the MS-SSIM calculation, + # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p). + mcs.pop() # Remove the cs score for the last scale. + mcs_and_ssim = tf.stack(mcs + [tf.nn.relu(ssim_per_channel)],axis=-1) + # Take weighted geometric mean across the scale axis. + ms_ssim = tf.reduce_prod(tf.pow(mcs_and_ssim, power_factors),[-1]) + + return tf.reduce_mean(ms_ssim, [-1]) # Avg over color channels. + + +def ms_ssim_loss(y_true,y_pred): + MSSSIM = K.expand_dims(K.expand_dims(1.0 - ms_ssim(y_true, y_pred),axis=-1), axis=-1) + return K.tile(MSSSIM,[1,64,64]) # need to expand to [1,height,width] dimensions for Keras ... modify to not be hard-coded diff --git a/lib/model/masks.py b/lib/model/masks.py new file mode 100644 index 0000000000..6fc6ee151b --- /dev/null +++ b/lib/model/masks.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" Masks functions for faceswap.py + Masks from: + dfaker: https://github.com/dfaker/df""" + +import logging + +import cv2 +import numpy as np + +from lib.umeyama import umeyama + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def dfaker(landmarks, face, channels=4): + """ Dfaker model mask + Embeds the mask into the face alpha channel + + channels: 1, 3 or 4: + 1 - Return a single channel mask + 3 - Return a 3 channel mask + 4 - Return the original image with the mask in the alpha channel + """ + padding = int(face.shape[0] * 0.1875) + coverage = face.shape[0] - (padding * 2) + logger.trace("face_shape: %s, coverage: %s, landmarks: %s", face.shape, coverage, landmarks) + + mat = umeyama(landmarks[17:], True)[0:2] + mat = np.array(mat.ravel()).reshape(2, 3) + mat = mat * coverage + mat[:, 2] += padding + + points = np.array(landmarks).reshape((-1, 2)) + facepoints = np.array(points).reshape((-1, 2)) + + mask = np.zeros_like(face, dtype=np.uint8) + + hull = cv2.convexHull(facepoints.astype(int)) # pylint: disable=no-member + hull = cv2.transform(hull.reshape(1, -1, 2), # pylint: disable=no-member + mat).reshape(-1, 2).astype(int) + cv2.fillConvexPoly(mask, hull, (255, 255, 255)) # pylint: disable=no-member + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)) # pylint: disable=no-member + mask = cv2.dilate(mask, # pylint: disable=no-member + kernel, + iterations=1, + borderType=cv2.BORDER_REFLECT) # pylint: disable=no-member + mask = mask[:, :, :1] + + return merge_mask(face, mask, channels) + + +def dfl_full(landmarks, face, channels=4): + """ DFL Face Full Mask + + channels: 1, 3 or 4: + 1 - Return a single channel mask + 3 - Return a 3 channel mask + 4 - Return the original image with the mask in the alpha channel + """ + logger.trace("face_shape: %s, landmarks: %s", face.shape, landmarks) + mask = np.zeros(face.shape[0:2] + (1, ), dtype=np.float32) + jaw = cv2.convexHull(np.concatenate(( # pylint: disable=no-member + landmarks[0:17], # jawline + landmarks[48:68], # mouth + [landmarks[0]], # temple + [landmarks[8]], # chin + [landmarks[16]]))) # temple + nose_ridge = cv2.convexHull(np.concatenate(( # pylint: disable=no-member + landmarks[27:31], # nose line + [landmarks[33]]))) # nose point + eyes = cv2.convexHull(np.concatenate(( # pylint: disable=no-member + landmarks[17:27], # eyebrows + [landmarks[0]], # temple + [landmarks[27]], # nose top + [landmarks[16]], # temple + [landmarks[33]]))) # nose point + + cv2.fillConvexPoly(mask, jaw, (255, 255, 255)) # pylint: disable=no-member + cv2.fillConvexPoly(mask, nose_ridge, (255, 255, 255)) # pylint: disable=no-member + cv2.fillConvexPoly(mask, eyes, (255, 255, 255)) # pylint: disable=no-member + return merge_mask(face, mask, channels) + + +def merge_mask(image, mask, channels): + """ Return the mask in requested shape """ + logger.trace("image_shape: %s, mask_shape: %s, channels: %s", + image.shape, mask.shape, channels) + assert channels in (1, 3, 4), "Channels should be 1, 3 or 4" + assert mask.shape[2] == 1 and mask.ndim == 3, "Input mask be 3 dimensions with 1 channel" + + if channels == 3: + retval = np.tile(mask, 3) + elif channels == 4: + retval = np.concatenate((image, mask), -1) + else: + retval = mask + + logger.trace("Final mask shape: %s", retval.shape) + return retval diff --git a/lib/model/nn_blocks.py b/lib/model/nn_blocks.py new file mode 100644 index 0000000000..b97545d987 --- /dev/null +++ b/lib/model/nn_blocks.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +""" Neural Network Blocks for faceswap.py + Blocks from: + the original https://www.reddit.com/r/deepfakes/ code sample + contribs + dfaker: https://github.com/dfaker/df + shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + +import logging +import tensorflow as tf +import keras.backend as K + +from keras.layers import (add, Add, BatchNormalization, concatenate, Lambda, regularizers, + Permute, Reshape, SeparableConv2D, Softmax, UpSampling2D) +from keras.layers.advanced_activations import LeakyReLU +from keras.layers.convolutional import Conv2D +from keras.layers.core import Activation +from keras.initializers import he_uniform, Constant +from .initializers import ICNR +from .layers import PixelShuffler, Scale, SubPixelUpscaling, ReflectionPadding2D +from .normalization import GroupNormalization, InstanceNormalization + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class NNBlocks(): + """ Blocks to use for creating models """ + def __init__(self, use_subpixel=False, use_icnr_init=False, use_reflect_padding=False): + logger.debug("Initializing %s: (use_subpixel: %s, use_icnr_init: %s, use_reflect_padding: %s", + self.__class__.__name__, use_subpixel, use_icnr_init, use_reflect_padding) + self.use_subpixel = use_subpixel + self.use_icnr_init = use_icnr_init + self.use_reflect_padding = use_reflect_padding + logger.debug("Initialized %s", self.__class__.__name__) + + @staticmethod + def update_kwargs(kwargs): + """ Set the default kernel initializer to he_uniform() """ + kwargs["kernel_initializer"] = kwargs.get("kernel_initializer", he_uniform()) + return kwargs + + # <<< Original Model Blocks >>> # + def conv(self, inp, filters, kernel_size=5, strides=2, padding='same', use_instance_norm=False, res_block_follows=False, **kwargs): + """ Convolution Layer""" + logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, use_instance_norm: %s, " + "kwargs: %s", inp, filters, kernel_size, strides, use_instance_norm, kwargs) + kwargs = self.update_kwargs(kwargs) + if self.use_reflect_padding: + inp = ReflectionPadding2D(stride=strides, kernel_size=kernel_size)(inp) + padding = 'valid' + var_x = Conv2D(filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + **kwargs)(inp) + if use_instance_norm: + var_x = InstanceNormalization()(var_x) + if not res_block_follows: + var_x = LeakyReLU(0.1)(var_x) + return var_x + + def upscale(self, inp, filters, kernel_size=3, padding= 'same', use_instance_norm=False, res_block_follows=False, **kwargs): + """ Upscale Layer """ + logger.debug("inp: %s, filters: %s, kernel_size: %s, use_instance_norm: %s, kwargs: %s", + inp, filters, kernel_size, use_instance_norm, kwargs) + kwargs = self.update_kwargs(kwargs) + if self.use_reflect_padding: + inp = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(inp) + padding = 'valid' + if self.use_icnr_init: + kwargs["kernel_initializer"] = ICNR(initializer=kwargs["kernel_initializer"]) + var_x = Conv2D(filters * 4, + kernel_size=kernel_size, + padding=padding, + **kwargs)(inp) + if use_instance_norm: + var_x = InstanceNormalization()(var_x) + if not res_block_follows: + var_x = LeakyReLU(0.1)(var_x) + if self.use_subpixel: + var_x = SubPixelUpscaling()(var_x) + else: + var_x = PixelShuffler()(var_x) + return var_x + + # <<< DFaker Model Blocks >>> # + def res_block(self, inp, filters, kernel_size=3, padding= 'same', **kwargs): + """ Residual block """ + logger.debug("inp: %s, filters: %s, kernel_size: %s, kwargs: %s", + inp, filters, kernel_size, kwargs) + kwargs = self.update_kwargs(kwargs) + var_x = LeakyReLU(alpha=0.2)(inp) + if self.use_reflect_padding: + var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x) + padding = 'valid' + var_x = Conv2D(filters, + kernel_size=kernel_size, + padding=padding, + **kwargs)(var_x) + var_x = LeakyReLU(alpha=0.2)(var_x) + if self.use_reflect_padding: + var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x) + padding = 'valid' + var_x = Conv2D(filters, + kernel_size=kernel_size, + padding=padding, + **kwargs)(var_x) + var_x = Scale(gamma_init=Constant(value=0.1))(var_x) + var_x = Add()([var_x, inp]) + var_x = LeakyReLU(alpha=0.2)(var_x) + return var_x + + # <<< Unbalanced Model Blocks >>> # + def conv_sep(self, inp, filters, kernel_size=5, strides=2, **kwargs): + """ Seperable Convolution Layer """ + logger.debug("inp: %s, filters: %s, kernel_size: %s, strides: %s, kwargs: %s", + inp, filters, kernel_size, strides, kwargs) + kwargs = self.update_kwargs(kwargs) + var_x = SeparableConv2D(filters, + kernel_size=kernel_size, + strides=strides, + padding='same', + **kwargs)(inp) + var_x = Activation("relu")(var_x) + return var_x + +# <<< GAN V2.2 Blocks >>> # +# TODO Merge these into NNBLock class when porting GAN2.2 + + +# Gan Constansts: +GAN22_CONV_INIT = "he_normal" +GAN22_REGULARIZER = 1e-4 + + +# Gan Blocks: +def normalization(inp, norm='none', group='16'): + """ GAN Normalization """ + if norm == 'layernorm': + var_x = GroupNormalization(group=group)(inp) + elif norm == 'batchnorm': + var_x = BatchNormalization()(inp) + elif norm == 'groupnorm': + var_x = GroupNormalization(group=16)(inp) + elif norm == 'instancenorm': + var_x = InstanceNormalization()(inp) + elif norm == 'hybrid': + if group % 2 == 1: + raise ValueError("Output channels must be an even number for hybrid norm, " + "received {}.".format(group)) + filt = group + var_x_0 = Lambda(lambda var_x: var_x[..., :filt // 2])(var_x) + var_x_1 = Lambda(lambda var_x: var_x[..., filt // 2:])(var_x) + var_x_0 = Conv2D(filt // 2, + kernel_size=1, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=GAN22_CONV_INIT)(var_x_0) + var_x_1 = InstanceNormalization()(var_x_1) + var_x = concatenate([var_x_0, var_x_1], axis=-1) + else: + var_x = inp + return var_x + + +def upscale_ps(inp, filters, initializer, use_norm=False, norm="none"): + """ GAN Upscaler - Pixel Shuffler """ + var_x = Conv2D(filters * 4, + kernel_size=3, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=initializer, + padding="same")(inp) + var_x = LeakyReLU(0.2)(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + var_x = PixelShuffler()(var_x) + return var_x + + +def upscale_nn(inp, filters, use_norm=False, norm="none"): + """ GAN Neural Network """ + var_x = UpSampling2D()(inp) + var_x = reflect_padding_2d(var_x, 1) + var_x = Conv2D(filters, + kernel_size=3, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer="he_normal")(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + return var_x + + +def reflect_padding_2d(inp, pad=1): + """ GAN Reflect Padding (2D) """ + var_x = Lambda(lambda var_x: tf.pad(var_x, + [[0, 0], [pad, pad], [pad, pad], [0, 0]], + mode="REFLECT"))(inp) + return var_x + + +def conv_gan(inp, filters, use_norm=False, strides=2, norm='none'): + """ GAN Conv Block """ + var_x = Conv2D(filters, + kernel_size=3, + strides=strides, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=GAN22_CONV_INIT, + use_bias=False, + padding="same")(inp) + var_x = Activation("relu")(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + return var_x + + +def conv_d_gan(inp, filters, use_norm=False, norm='none'): + """ GAN Discriminator Conv Block """ + var_x = inp + var_x = Conv2D(filters, + kernel_size=4, + strides=2, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=GAN22_CONV_INIT, + use_bias=False, + padding="same")(var_x) + var_x = LeakyReLU(alpha=0.2)(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + return var_x + + +def res_block_gan(inp, filters, use_norm=False, norm='none'): + """ GAN Res Block """ + var_x = Conv2D(filters, + kernel_size=3, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=GAN22_CONV_INIT, + use_bias=False, + padding="same")(inp) + var_x = LeakyReLU(alpha=0.2)(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + var_x = Conv2D(filters, + kernel_size=3, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER), + kernel_initializer=GAN22_CONV_INIT, + use_bias=False, + padding="same")(var_x) + var_x = add([var_x, inp]) + var_x = LeakyReLU(alpha=0.2)(var_x) + var_x = normalization(var_x, norm, filters) if use_norm else var_x + return var_x + + +def self_attn_block(inp, n_c, squeeze_factor=8): + """ GAN Self Attention Block + Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow + """ + msg = "Input channels must be >= {}, recieved nc={}".format(squeeze_factor, n_c) + assert n_c // squeeze_factor > 0, msg + var_x = inp + shape_x = var_x.get_shape().as_list() + + var_f = Conv2D(n_c // squeeze_factor, 1, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) + var_g = Conv2D(n_c // squeeze_factor, 1, + kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) + var_h = Conv2D(n_c, 1, kernel_regularizer=regularizers.l2(GAN22_REGULARIZER))(var_x) + + shape_f = var_f.get_shape().as_list() + shape_g = var_g.get_shape().as_list() + shape_h = var_h.get_shape().as_list() + flat_f = Reshape((-1, shape_f[-1]))(var_f) + flat_g = Reshape((-1, shape_g[-1]))(var_g) + flat_h = Reshape((-1, shape_h[-1]))(var_h) + + var_s = Lambda(lambda var_x: K.batch_dot(var_x[0], + Permute((2, 1))(var_x[1])))([flat_g, flat_f]) + + beta = Softmax(axis=-1)(var_s) + var_o = Lambda(lambda var_x: K.batch_dot(var_x[0], var_x[1]))([beta, flat_h]) + var_o = Reshape(shape_x[1:])(var_o) + var_o = Scale()(var_o) + + out = add([var_o, inp]) + return out diff --git a/lib/model/normalization.py b/lib/model/normalization.py new file mode 100644 index 0000000000..ec4dbb1f5e --- /dev/null +++ b/lib/model/normalization.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" Normaliztion methods for faceswap.py + Code from: + shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + +import sys +import inspect + +from keras.engine import Layer, InputSpec +from keras import initializers, regularizers, constraints +from keras import backend as K +from keras.utils.generic_utils import get_custom_objects + + +def to_list(inp): + """ Convert to list """ + if not isinstance(inp, (list, tuple)): + return [inp] + return list(inp) + + +class InstanceNormalization(Layer): + """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). + Normalize the activations of the previous layer at each step, + i.e. applies a transformation that maintains the mean activation + close to 0 and the activation standard deviation close to 1. + # Arguments + axis: Integer, the axis that should be normalized + (typically the features axis). + For instance, after a `Conv2D` layer with + `data_format="channels_first"`, + set `axis=1` in `InstanceNormalization`. + Setting `axis=None` will normalize all values in each instance of the batch. + Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: If True, multiply by `gamma`. + If False, `gamma` is not used. + When the next layer is linear (also e.g. `nn.relu`), + this can be disabled since the scaling + will be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. + # Input shape + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + # Output shape + Same shape as input. + # References + - [Layer Normalization](https://arxiv.org/abs/1607.06450) + - [Instance Normalization: The Missing Ingredient for Fast + Stylization](https://arxiv.org/abs/1607.08022) + """ + def __init__(self, + axis=None, + epsilon=1e-3, + center=True, + scale=True, + beta_initializer='zeros', + gamma_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + **kwargs): + self.beta = None + self.gamma = None + super(InstanceNormalization, self).__init__(**kwargs) + self.supports_masking = True + self.axis = axis + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = initializers.get(beta_initializer) + self.gamma_initializer = initializers.get(gamma_initializer) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_constraint = constraints.get(beta_constraint) + self.gamma_constraint = constraints.get(gamma_constraint) + + def build(self, input_shape): + ndim = len(input_shape) + if self.axis == 0: + raise ValueError('Axis cannot be zero') + + if (self.axis is not None) and (ndim == 2): + raise ValueError('Cannot specify axis for rank 1 tensor') + + self.input_spec = InputSpec(ndim=ndim) + + if self.axis is None: + shape = (1,) + else: + shape = (input_shape[self.axis],) + + if self.scale: + self.gamma = self.add_weight(shape=shape, + name='gamma', + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint) + else: + self.gamma = None + if self.center: + self.beta = self.add_weight(shape=shape, + name='beta', + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint) + else: + self.beta = None + self.built = True + + def call(self, inputs, training=None): + input_shape = K.int_shape(inputs) + reduction_axes = list(range(0, len(input_shape))) + + if self.axis is not None: + del reduction_axes[self.axis] + + del reduction_axes[0] + + mean = K.mean(inputs, reduction_axes, keepdims=True) + stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon + normed = (inputs - mean) / stddev + + broadcast_shape = [1] * len(input_shape) + if self.axis is not None: + broadcast_shape[self.axis] = input_shape[self.axis] + + if self.scale: + broadcast_gamma = K.reshape(self.gamma, broadcast_shape) + normed = normed * broadcast_gamma + if self.center: + broadcast_beta = K.reshape(self.beta, broadcast_shape) + normed = normed + broadcast_beta + return normed + + def get_config(self): + config = { + 'axis': self.axis, + 'epsilon': self.epsilon, + 'center': self.center, + 'scale': self.scale, + 'beta_initializer': initializers.serialize(self.beta_initializer), + 'gamma_initializer': initializers.serialize(self.gamma_initializer), + 'beta_regularizer': regularizers.serialize(self.beta_regularizer), + 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), + 'beta_constraint': constraints.serialize(self.beta_constraint), + 'gamma_constraint': constraints.serialize(self.gamma_constraint) + } + base_config = super(InstanceNormalization, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class GroupNormalization(Layer): + """ Group Normalization + from: shoanlu GAN: https://github.com/shaoanlu/faceswap-GAN""" + + def __init__(self, axis=-1, + gamma_init='one', beta_init='zero', + gamma_regularizer=None, beta_regularizer=None, + epsilon=1e-6, + group=32, + data_format=None, + **kwargs): + self.beta = None + self.gamma = None + super(GroupNormalization, self).__init__(**kwargs) + + self.axis = to_list(axis) + self.gamma_init = initializers.get(gamma_init) + self.beta_init = initializers.get(beta_init) + self.gamma_regularizer = regularizers.get(gamma_regularizer) + self.beta_regularizer = regularizers.get(beta_regularizer) + self.epsilon = epsilon + self.group = group + self.data_format = K.normalize_data_format(data_format) + + self.supports_masking = True + + def build(self, input_shape): + self.input_spec = [InputSpec(shape=input_shape)] + shape = [1 for _ in input_shape] + if self.data_format == 'channels_last': + channel_axis = -1 + shape[channel_axis] = input_shape[channel_axis] + elif self.data_format == 'channels_first': + channel_axis = 1 + shape[channel_axis] = input_shape[channel_axis] + # for i in self.axis: + # shape[i] = input_shape[i] + self.gamma = self.add_weight(shape=shape, + initializer=self.gamma_init, + regularizer=self.gamma_regularizer, + name='gamma') + self.beta = self.add_weight(shape=shape, + initializer=self.beta_init, + regularizer=self.beta_regularizer, + name='beta') + self.built = True + + def call(self, inputs, mask=None): + input_shape = K.int_shape(inputs) + if len(input_shape) != 4 and len(input_shape) != 2: + raise ValueError('Inputs should have rank ' + + str(4) + " or " + str(2) + + '; Received input shape:', str(input_shape)) + + if len(input_shape) == 4: + if self.data_format == 'channels_last': + batch_size, height, width, channels = input_shape + if batch_size is None: + batch_size = -1 + + if channels < self.group: + raise ValueError('Input channels should be larger than group size' + + '; Received input channels: ' + str(channels) + + '; Group size: ' + str(self.group)) + + var_x = K.reshape(inputs, (batch_size, + height, + width, + self.group, + channels // self.group)) + mean = K.mean(var_x, axis=[1, 2, 4], keepdims=True) + std = K.sqrt(K.var(var_x, axis=[1, 2, 4], keepdims=True) + self.epsilon) + var_x = (var_x - mean) / std + + var_x = K.reshape(var_x, (batch_size, height, width, channels)) + retval = self.gamma * var_x + self.beta + elif self.data_format == 'channels_first': + batch_size, channels, height, width = input_shape + if batch_size is None: + batch_size = -1 + + if channels < self.group: + raise ValueError('Input channels should be larger than group size' + + '; Received input channels: ' + str(channels) + + '; Group size: ' + str(self.group)) + + var_x = K.reshape(inputs, (batch_size, + self.group, + channels // self.group, + height, + width)) + mean = K.mean(var_x, axis=[2, 3, 4], keepdims=True) + std = K.sqrt(K.var(var_x, axis=[2, 3, 4], keepdims=True) + self.epsilon) + var_x = (var_x - mean) / std + + var_x = K.reshape(var_x, (batch_size, channels, height, width)) + retval = self.gamma * var_x + self.beta + + elif len(input_shape) == 2: + reduction_axes = list(range(0, len(input_shape))) + del reduction_axes[0] + batch_size, _ = input_shape + if batch_size is None: + batch_size = -1 + + mean = K.mean(inputs, keepdims=True) + std = K.sqrt(K.var(inputs, keepdims=True) + self.epsilon) + var_x = (inputs - mean) / std + + retval = self.gamma * var_x + self.beta + return retval + + def get_config(self): + config = {'epsilon': self.epsilon, + 'axis': self.axis, + 'gamma_init': initializers.serialize(self.gamma_init), + 'beta_init': initializers.serialize(self.beta_init), + 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), + 'beta_regularizer': regularizers.serialize(self.gamma_regularizer), + 'group': self.group} + base_config = super(GroupNormalization, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +# Update normalizations into Keras custom objects +for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and obj.__module__ == __name__: + get_custom_objects().update({name: obj}) diff --git a/lib/multithreading.py b/lib/multithreading.py index 3f4ede12e0..55e9bb0cbc 100644 --- a/lib/multithreading.py +++ b/lib/multithreading.py @@ -117,6 +117,8 @@ def run(self): self._target(*self._args, **self._kwargs) except Exception: # pylint: disable=broad-except self.err = sys.exc_info() + logger.debug("Error in thread (%s): %s", self._name, + self.err[1].with_traceback(self.err[2])) finally: # Avoid a refcycle if the thread is running a function with # an argument that has a member that points to the thread. @@ -126,8 +128,8 @@ def run(self): class MultiThread(): """ Threading for IO heavy ops Catches errors in thread and rethrows to parent """ - def __init__(self, target, *args, thread_count=1, **kwargs): - self._name = target.__name__ + def __init__(self, target, *args, thread_count=1, name=None, **kwargs): + self._name = name if name else target.__name__ logger.debug("Initializing %s: (target: '%s', thread_count: %s)", self.__class__.__name__, self._name, thread_count) logger.trace("args: %s, kwargs: %s", args, kwargs) @@ -139,6 +141,16 @@ def __init__(self, target, *args, thread_count=1, **kwargs): self._kwargs = kwargs logger.debug("Initialized %s: '%s'", self.__class__.__name__, self._name) + @property + def has_error(self): + """ Return true if a thread has errored, otherwise false """ + return any(thread.err for thread in self._threads) + + @property + def errors(self): + """ Return a list of thread errors """ + return [thread.err for thread in self._threads] + def start(self): """ Start a thread with the given method and args """ logger.debug("Starting thread(s): '%s'", self._name) diff --git a/lib/queue_manager.py b/lib/queue_manager.py index 2063736630..6acbfd05cb 100644 --- a/lib/queue_manager.py +++ b/lib/queue_manager.py @@ -9,7 +9,7 @@ import sys import threading -from queue import Empty as QueueEmpty # pylint: disable=unused-import; # noqa +from queue import Queue, Empty as QueueEmpty # pylint: disable=unused-import; # noqa from time import sleep logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -37,7 +37,7 @@ def __init__(self): self._log_queue = self.manager.Queue() if "gui" not in sys.argv else mp.Queue() logger.debug("Initialized %s", self.__class__.__name__) - def add_queue(self, name, maxsize=0): + def add_queue(self, name, maxsize=0, multiprocessing_queue=True): """ Add a queue to the manager Adds an event "shutdown" to the queue that can be used to indicate @@ -46,7 +46,12 @@ def add_queue(self, name, maxsize=0): logger.debug("QueueManager adding: (name: '%s', maxsize: %s)", name, maxsize) if name in self.queues.keys(): raise ValueError("Queue '{}' already exists.".format(name)) - queue = self.manager.Queue(maxsize=maxsize) + + if multiprocessing_queue: + queue = self.manager.Queue(maxsize=maxsize) + else: + queue = Queue(maxsize=maxsize) + setattr(queue, "shutdown", self.shutdown) self.queues[name] = queue logger.debug("QueueManager added: (name: '%s')", name) diff --git a/lib/training_data.py b/lib/training_data.py index fad52eec73..32037c9245 100644 --- a/lib/training_data.py +++ b/lib/training_data.py @@ -1,106 +1,401 @@ -from random import shuffle -import cv2 -import numpy - -from .multithreading import BackgroundGenerator -from .umeyama import umeyama - -class TrainingDataGenerator(): - def __init__(self, random_transform_args, coverage, scale=5, zoom=1): #TODO thos default should stay in the warp function - self.random_transform_args = random_transform_args - self.coverage = coverage - self.scale = scale - self.zoom = zoom - - def minibatchAB(self, images, batchsize, doShuffle=True): - batch = BackgroundGenerator(self.minibatch(images, batchsize, doShuffle), 1) - for ep1, warped_img, target_img in batch.iterator(): - yield ep1, warped_img, target_img - - # A generator function that yields epoch, batchsize of warped_img and batchsize of target_img - def minibatch(self, data, batchsize, doShuffle=True): - length = len(data) - assert length >= batchsize, "Number of images is lower than batch-size (Note that too few images may lead to bad training). # images: {}, batch-size: {}".format(length, batchsize) - epoch = i = 0 - if doShuffle: - shuffle(data) - while True: - size = batchsize - if i+size > length: - if doShuffle: - shuffle(data) - i = 0 - epoch+=1 - rtn = numpy.float32([self.read_image(img) for img in data[i:i+size]]) - i+=size - yield epoch, rtn[:,0,:,:,:], rtn[:,1,:,:,:] - - def color_adjust(self, img): - return img / 255.0 - - def read_image(self, fn): - try: - image = self.color_adjust(cv2.imread(fn)) - except TypeError: - raise Exception("Error while reading image", fn) - - image = cv2.resize(image, (256,256)) - image = self.random_transform( image, **self.random_transform_args ) - warped_img, target_img = self.random_warp( image, self.coverage, self.scale, self.zoom ) - - return warped_img, target_img - - def random_transform(self, image, rotation_range, zoom_range, shift_range, random_flip): - h, w = image.shape[0:2] - rotation = numpy.random.uniform(-rotation_range, rotation_range) - scale = numpy.random.uniform(1 - zoom_range, 1 + zoom_range) - tx = numpy.random.uniform(-shift_range, shift_range) * w - ty = numpy.random.uniform(-shift_range, shift_range) * h - mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale) - mat[:, 2] += (tx, ty) - result = cv2.warpAffine( - image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE) - if numpy.random.random() < random_flip: - result = result[:, ::-1] - return result - - # get pair of random warped images from aligned face image - def random_warp(self, image, coverage, scale = 5, zoom = 1): - assert image.shape == (256, 256, 3) - range_ = numpy.linspace(128 - coverage//2, 128 + coverage//2, 5) - mapx = numpy.broadcast_to(range_, (5, 5)) - mapy = mapx.T - - mapx = mapx + numpy.random.normal(size=(5,5), scale=scale) - mapy = mapy + numpy.random.normal(size=(5,5), scale=scale) - - interp_mapx = cv2.resize(mapx, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32') - interp_mapy = cv2.resize(mapy, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32') - - warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR) - - src_points = numpy.stack([mapx.ravel(), mapy.ravel() ], axis=-1) - dst_points = numpy.mgrid[0:65*zoom:16*zoom,0:65*zoom:16*zoom].T.reshape(-1,2) - mat = umeyama(src_points, dst_points, True)[0:2] - - target_image = cv2.warpAffine(image, mat, (64*zoom,64*zoom)) - - return warped_image, target_image - -def stack_images(images): - def get_transpose_axes(n): - if n % 2 == 0: - y_axes = list(range(1, n - 1, 2)) - x_axes = list(range(0, n - 1, 2)) - else: - y_axes = list(range(0, n - 1, 2)) - x_axes = list(range(1, n - 1, 2)) - return y_axes, x_axes, [n - 1] - - images_shape = numpy.array(images.shape) - new_axes = get_transpose_axes(len(images_shape)) - new_shape = [numpy.prod(images_shape[x]) for x in new_axes] - return numpy.transpose( - images, - axes=numpy.concatenate(new_axes) - ).reshape(new_shape) +#!/usr/bin/env python3 +""" Process training data for model training """ + +import logging + +from hashlib import sha1 +from random import shuffle + +import cv2 +import numpy as np +from scipy.interpolate import griddata + +from lib.model import masks +from lib.multithreading import MultiThread +from lib.queue_manager import queue_manager +from lib.umeyama import umeyama + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class TrainingDataGenerator(): + """ Generate training data for models """ + def __init__(self, model_input_size, model_output_size, training_opts): + logger.debug("Initializing %s: (model_input_size: %s, model_output_shape: %s, " + "training_opts: %s, landmarks: %s)", + self.__class__.__name__, model_input_size, model_output_size, + {key: val for key, val in training_opts.items() if key != "landmarks"}, + bool(training_opts.get("landmarks", None))) + self.batchsize = 0 + self.model_input_size = model_input_size + self.training_opts = training_opts + self.mask_function = self.set_mask_function() + self.landmarks = self.training_opts.get("landmarks", None) + + self.processing = ImageManipulation(model_input_size, + model_output_size, + training_opts.get("coverage_ratio", 0.625)) + logger.debug("Initialized %s", self.__class__.__name__) + + def set_mask_function(self): + """ Set the mask function to use if using mask """ + mask_type = self.training_opts.get("mask_type", None) + if mask_type: + logger.debug("Mask type: '%s'", mask_type) + mask_func = getattr(masks, mask_type) + else: + mask_func = None + logger.debug("Mask function: %s", mask_func) + return mask_func + + def minibatch_ab(self, images, batchsize, side, do_shuffle=True, is_timelapse=False): + """ Keep a queue filled to 8x Batch Size """ + logger.debug("Queue batches: (image_count: %s, batchsize: %s, side: '%s', do_shuffle: %s, " + "is_timelapse: %s)", len(images), batchsize, side, do_shuffle, is_timelapse) + self.batchsize = batchsize + q_name = "timelapse_{}".format(side) if is_timelapse else "train_{}".format(side) + q_size = batchsize * 8 + # Don't use a multiprocessing queue because sometimes the MP Manager borks on numpy arrays + queue_manager.add_queue(q_name, maxsize=q_size, multiprocessing_queue=False) + load_thread = MultiThread(self.load_batches, + images, + q_name, + side, + is_timelapse, + do_shuffle) + load_thread.start() + logger.debug("Batching to queue: (side: '%s', queue: '%s')", side, q_name) + return self.minibatch(q_name, load_thread) + + def load_batches(self, images, q_name, side, is_timelapse, do_shuffle=True): + """ Load the warped images and target images to queue """ + logger.debug("Loading batch: (image_count: %s, q_name: '%s', side: '%s', " + "is_timelapse: %s, do_shuffle: %s)", + len(images), q_name, side, is_timelapse, do_shuffle) + epoch = 0 + queue = queue_manager.get_queue(q_name) + self.validate_samples(images) + while True: + if do_shuffle: + shuffle(images) + for img in images: + logger.trace("Putting to batch queue: (q_name: '%s', side: '%s')", q_name, side) + queue.put(self.process_face(img, side, is_timelapse)) + epoch += 1 + logger.debug("Finished batching: (epoch: %s, q_name: '%s', side: '%s')", + epoch, q_name, side) + + def validate_samples(self, data): + """ Check the total number of images against batchsize and return + the total number of images """ + length = len(data) + msg = ("Number of images is lower than batch-size (Note that too few " + "images may lead to bad training). # images: {}, " + "batch-size: {}".format(length, self.batchsize)) + assert length >= self.batchsize, msg + + def minibatch(self, q_name, load_thread): + """ A generator function that yields epoch, batchsize of warped_img + and batchsize of target_img from the load queue """ + logger.debug("Launching minibatch generator for queue: '%s'", q_name) + queue = queue_manager.get_queue(q_name) + while True: + if load_thread.has_error: + logger.debug("Thread error detected") + break + batch = list() + for _ in range(self.batchsize): + images = queue.get() + for idx, image in enumerate(images): + if len(batch) < idx + 1: + batch.append(list()) + batch[idx].append(image) + batch = [np.float32(image) for image in batch] + logger.trace("Yielding batch: (size: %s, item shapes: %s, queue: '%s'", + len(batch), [item.shape for item in batch], q_name) + yield batch + logger.debug("Finished minibatch generator for queue: '%s'", q_name) + load_thread.join() + + def process_face(self, filename, side, is_timelapse): + """ Load an image and perform transformation and warping """ + logger.trace("Process face: (filename: '%s', side: '%s', is_timelapse: %s)", + filename, side, is_timelapse) + try: + image = cv2.imread(filename) # pylint: disable=no-member + except TypeError: + raise Exception("Error while reading image", filename) + + if self.mask_function or self.training_opts["warp_to_landmarks"]: + src_pts = self.get_landmarks(filename, image, side) + if self.mask_function: + image = self.mask_function(src_pts, image, channels=4) + + image = self.processing.color_adjust(image) + + if not is_timelapse: + image = self.processing.random_transform(image) + if not self.training_opts["no_flip"]: + image = self.processing.do_random_flip(image) + sample = image.copy()[:, :, :3] + + if self.training_opts["warp_to_landmarks"]: + dst_pts = self.get_closest_match(filename, side, src_pts) + processed = self.processing.random_warp_landmarks(image, src_pts, dst_pts) + else: + processed = self.processing.random_warp(image) + + processed.insert(0, sample) + logger.trace("Processed face: (filename: '%s', side: '%s', shapes: %s)", + filename, side, [img.shape for img in processed]) + return processed + + def get_landmarks(self, filename, image, side): + """ Return the landmarks for this face """ + logger.trace("Retrieving landmarks: (filename: '%s', side: '%s'", filename, side) + lm_key = sha1(image).hexdigest() + try: + src_points = self.landmarks[side][lm_key] + except KeyError: + raise Exception("Landmarks not found for hash: '{}' file: '{}'".format(lm_key, + filename)) + logger.trace("Returning: (src_points: %s)", src_points) + return src_points + + def get_closest_match(self, filename, side, src_points): + """ Return closest matched landmarks from opposite set """ + logger.trace("Retrieving closest matched landmarks: (filename: '%s', src_points: '%s'", + filename, src_points) + dst_points = self.landmarks["a"] if side == "b" else self.landmarks["b"] + dst_points = list(dst_points.values()) + closest = (np.mean(np.square(src_points - dst_points), + axis=(1, 2))).argsort()[:10] + closest = np.random.choice(closest) + dst_points = dst_points[closest] + logger.trace("Returning: (dst_points: %s)", dst_points) + return dst_points + + +class ImageManipulation(): + """ Manipulations to be performed on training images """ + def __init__(self, input_size, output_size, coverage_ratio): + """ input_size: Size of the face input into the model + output_size: Size of the face that comes out of the modell + coverage_ratio: Coverage ratio of full image. Eg: 256 * 0.625 = 160 + """ + logger.debug("Initializing %s: (input_size: %s, output_size: %s, coverage_ratio: %s)", + self.__class__.__name__, input_size, output_size, coverage_ratio) + # Transform args + self.rotation_range = 10 # Range to randomly rotate the image by + self.zoom_range = 0.05 # Range to randomly zoom the image by + self.shift_range = 0.05 # Range to randomly translate the image by + self.random_flip = 0.5 # Chance to flip the image horizontally + # Transform and Warp args + self.input_size = input_size + self.output_size = output_size + # Warp args + self.coverage_ratio = coverage_ratio # Coverage ratio of full image. Eg: 256 * 0.625 = 160 + self.scale = 5 # Normal random variable scale + logger.debug("Initialized %s", self.__class__.__name__) + + @staticmethod + def color_adjust(img): + """ Color adjust RGB image """ + logger.trace("Color adjusting image") + return img.astype('float32') / 255.0 + + @staticmethod + def separate_mask(image): + """ Return the image and the mask from a 4 channel image """ + mask = None + if image.shape[2] == 4: + logger.trace("Image contains mask") + mask = np.expand_dims(image[:, :, -1], axis=2) + image = image[:, :, :3] + else: + logger.trace("Image has no mask") + return image, mask + + def get_coverage(self, image): + """ Return coverage value for given image """ + coverage = int(image.shape[0] * self.coverage_ratio) + logger.trace("Coverage: %s", coverage) + return coverage + + def random_transform(self, image): + """ Randomly transform an image """ + logger.trace("Randomly transforming image") + height, width = image.shape[0:2] + + rotation = np.random.uniform(-self.rotation_range, self.rotation_range) + scale = np.random.uniform(1 - self.zoom_range, 1 + self.zoom_range) + tnx = np.random.uniform(-self.shift_range, self.shift_range) * width + tny = np.random.uniform(-self.shift_range, self.shift_range) * height + + mat = cv2.getRotationMatrix2D( # pylint: disable=no-member + (width // 2, height // 2), rotation, scale) + mat[:, 2] += (tnx, tny) + result = cv2.warpAffine( # pylint: disable=no-member + image, mat, (width, height), + borderMode=cv2.BORDER_REPLICATE) # pylint: disable=no-member + + logger.trace("Randomly transformed image") + return result + + def do_random_flip(self, image): + """ Perform flip on image if random number is within threshold """ + logger.trace("Randomly flipping image") + if np.random.random() < self.random_flip: + logger.trace("Flip within threshold. Flipping") + retval = image[:, ::-1] + else: + logger.trace("Flip outside threshold. Not Flipping") + retval = image + logger.trace("Randomly flipped image") + return retval + + def random_warp(self, image): + """ get pair of random warped images from aligned face image """ + logger.trace("Randomly warping image") + height, width = image.shape[0:2] + coverage = self.get_coverage(image) + assert height == width and height % 2 == 0 + + range_ = np.linspace(height // 2 - coverage // 2, + height // 2 + coverage // 2, + 5, dtype='float32') + mapx = np.broadcast_to(range_, (5, 5)).copy() + mapy = mapx.T + # mapx, mapy = np.float32(np.meshgrid(range_,range_)) # instead of broadcast + + pad = int(1.25 * self.input_size) + slices = slice(pad // 10, -pad // 10) + dst_slice = slice(0, (self.output_size + 1), (self.output_size // 4)) + interp = np.empty((2, self.input_size, self.input_size), dtype='float32') + #### + + for i, map_ in enumerate([mapx, mapy]): + map_ = map_ + np.random.normal(size=(5, 5), scale=self.scale) + interp[i] = cv2.resize(map_, (pad, pad))[slices, slices] # pylint: disable=no-member + + warped_image = cv2.remap( # pylint: disable=no-member + image, interp[0], interp[1], cv2.INTER_LINEAR) # pylint: disable=no-member + logger.trace("Warped image shape: %s", warped_image.shape) + + src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1) + dst_points = np.mgrid[dst_slice, dst_slice] + mat = umeyama(src_points, True, dst_points.T.reshape(-1, 2))[0:2] + target_image = cv2.warpAffine( # pylint: disable=no-member + image, mat, (self.output_size, self.output_size)) + logger.trace("Target image shape: %s", target_image.shape) + + warped_image, warped_mask = self.separate_mask(warped_image) + target_image, target_mask = self.separate_mask(target_image) + + if target_mask is None: + logger.trace("Randomly warped image") + return [warped_image, target_image] + + logger.trace("Target mask shape: %s", target_mask.shape) + logger.trace("Randomly warped image and mask") + return [warped_image, target_image, target_mask] + + def random_warp_landmarks(self, image, src_points=None, dst_points=None): + """ get warped image, target image and target mask + From DFAKER plugin """ + logger.trace("Randomly warping landmarks") + size = image.shape[0] + coverage = self.get_coverage(image) + + p_mx = size - 1 + p_hf = (size // 2) - 1 + + edge_anchors = [(0, 0), (0, p_mx), (p_mx, p_mx), (p_mx, 0), + (p_hf, 0), (p_hf, p_mx), (p_mx, p_hf), (0, p_hf)] + grid_x, grid_y = np.mgrid[0:p_mx:complex(size), 0:p_mx:complex(size)] + + source = src_points + destination = (dst_points.copy().astype('float32') + + np.random.normal(size=dst_points.shape, scale=2.0)) + destination = destination.astype('uint8') + + face_core = cv2.convexHull(np.concatenate( # pylint: disable=no-member + [source[17:], destination[17:]], axis=0).astype(int)) + + source = [(pty, ptx) for ptx, pty in source] + edge_anchors + destination = [(pty, ptx) for ptx, pty in destination] + edge_anchors + + indicies_to_remove = set() + for fpl in source, destination: + for idx, (pty, ptx) in enumerate(fpl): + if idx > 17: + break + elif cv2.pointPolygonTest(face_core, # pylint: disable=no-member + (pty, ptx), + False) >= 0: + indicies_to_remove.add(idx) + + for idx in sorted(indicies_to_remove, reverse=True): + source.pop(idx) + destination.pop(idx) + + grid_z = griddata(destination, source, (grid_x, grid_y), method="linear") + map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(size, size) + map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(size, size) + map_x_32 = map_x.astype('float32') + map_y_32 = map_y.astype('float32') + + warped_image = cv2.remap(image, # pylint: disable=no-member + map_x_32, + map_y_32, + cv2.INTER_LINEAR, # pylint: disable=no-member + cv2.BORDER_TRANSPARENT) # pylint: disable=no-member + target_image = image + + # TODO Make sure this replacement is correct + slices = slice(size // 2 - coverage // 2, size // 2 + coverage // 2) +# slices = slice(size // 32, size - size // 32) # 8px on a 256px image + warped_image = cv2.resize( # pylint: disable=no-member + warped_image[slices, slices, :], (self.input_size, self.input_size), + cv2.INTER_AREA) # pylint: disable=no-member + logger.trace("Warped image shape: %s", warped_image.shape) + target_image = cv2.resize( # pylint: disable=no-member + target_image[slices, slices, :], (self.output_size, self.output_size), + cv2.INTER_AREA) # pylint: disable=no-member + logger.trace("Target image shape: %s", target_image.shape) + + warped_image, warped_mask = self.separate_mask(warped_image) + target_image, target_mask = self.separate_mask(target_image) + + if target_mask is None: + logger.trace("Randomly warped image") + return [warped_image, target_image] + + logger.trace("Target mask shape: %s", target_mask.shape) + logger.trace("Randomly warped image and mask") + return [warped_image, target_image, target_mask] + + +def stack_images(images): + """ Stack images """ + logger.debug("Stack images") + + def get_transpose_axes(num): + if num % 2 == 0: + logger.debug("Even number of images to stack") + y_axes = list(range(1, num - 1, 2)) + x_axes = list(range(0, num - 1, 2)) + else: + logger.debug("Odd number of images to stack") + y_axes = list(range(0, num - 1, 2)) + x_axes = list(range(1, num - 1, 2)) + return y_axes, x_axes, [num - 1] + + images_shape = np.array(images.shape) + new_axes = get_transpose_axes(len(images_shape)) + new_shape = [np.prod(images_shape[x]) for x in new_axes] + logger.debug("Stacked images") + return np.transpose( + images, + axes=np.concatenate(new_axes) + ).reshape(new_shape) diff --git a/lib/umeyama.py b/lib/umeyama.py index a83548491f..f3af365a9a 100644 --- a/lib/umeyama.py +++ b/lib/umeyama.py @@ -12,8 +12,27 @@ import numpy as np +MEAN_FACE_X = np.array([ + 0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, + 0.799124, 0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, + 0.36688, 0.426036, 0.490127, 0.554217, 0.613373, 0.121737, 0.187122, + 0.265825, 0.334606, 0.260918, 0.182743, 0.645647, 0.714428, 0.793132, + 0.858516, 0.79751, 0.719335, 0.254149, 0.340985, 0.428858, 0.490127, + .551395, 0.639268, 0.726104, 0.642159, 0.556721, 0.490127, 0.423532, + 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, 0.553364, + 0.490127, 0.42689]) -def umeyama(src, dst, estimate_scale): +MEAN_FACE_Y = np.array([ + 0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, + 0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, + 0.587326, 0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, + 0.179852, 0.231733, 0.245099, 0.244077, 0.231733, 0.179852, 0.178758, + 0.216423, 0.244077, 0.245099, 0.780233, 0.745405, 0.727388, 0.742578, + 0.727388, 0.745405, 0.780233, 0.864805, 0.902192, 0.909281, 0.902192, + 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, 0.784792, 0.824182, + 0.831803, 0.824182]) + +def umeyama(src, estimate_scale, dst=None): """Estimate N-D similarity transformation with or without scaling. Parameters ---------- @@ -33,6 +52,8 @@ def umeyama(src, dst, estimate_scale): .. [1] "Least-squares estimation of transformation parameters between two point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 """ + if dst is None: + dst = np.stack([MEAN_FACE_X, MEAN_FACE_Y], axis=1) num = src.shape[0] dim = src.shape[1] diff --git a/lib/utils.py b/lib/utils.py index 29193e9867..72d5dd2f15 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -8,7 +8,6 @@ from hashlib import sha1 from pathlib import Path from re import finditer -from time import time import cv2 import numpy as np @@ -16,7 +15,6 @@ import dlib from lib.faces_detect import DetectedFace -from lib.training_data import TrainingDataGenerator from lib.logger import get_loglevel @@ -62,7 +60,7 @@ def get_image_paths(directory): def hash_image_file(filename): - """ Return the filename with it's sha1 hash """ + """ Return an image file's sha1 hash """ img = cv2.imread(filename) # pylint: disable=no-member img_hash = sha1(img).hexdigest() logger.trace("filename: '%s', hash: %s", filename, img_hash) @@ -107,33 +105,12 @@ def set_system_verbosity(loglevel): logger.debug("System Verbosity level: %s", loglevel) os.environ['TF_CPP_MIN_LOG_LEVEL'] = loglevel if loglevel != '0': - for warncat in (FutureWarning, DeprecationWarning): + for warncat in (FutureWarning, DeprecationWarning, UserWarning): warnings.simplefilter(action='ignore', category=warncat) -def add_alpha_channel(image, intensity=100): - """ Add an alpha channel to an image - - intensity: The opacity of the alpha channel between 0 and 100 - 100 = transparent, - 0 = solid """ - logger.trace("Adding alpha channel: intensity: %s", intensity) - assert 0 <= intensity <= 100, "Invalid intensity supplied" - intensity = (255.0 / 100.0) * intensity - - d_type = image.dtype - image = image.astype("float32") - - ch_b, ch_g, ch_r = cv2.split(image) # pylint: disable=no-member - ch_a = np.ones(ch_b.shape, dtype="float32") * intensity - - image_bgra = cv2.merge( # pylint: disable=no-member - (ch_b, ch_g, ch_r, ch_a)) - logger.trace("Added alpha channel", intensity) - return image_bgra.astype(d_type) - - def rotate_landmarks(face, rotation_matrix): + # pylint: disable=c-extension-no-member """ Rotate the landmarks and bounding box for faces found in rotated images. Pass in a DetectedFace object, Alignments dict or DLib rectangle""" @@ -223,80 +200,6 @@ def camel_case_split(identifier): return [m.group(0) for m in matches] -class Timelapse: - """ Time lapse function for training """ - @classmethod - def create_timelapse(cls, input_dir_a, input_dir_b, output_dir, trainer): - """ Create the time lapse """ - if input_dir_a is None and input_dir_b is None and output_dir is None: - return None - - if input_dir_a is None or input_dir_b is None: - raise ValueError("To enable the timelapse, you have to supply " - "all the parameters (--timelapse-input-A and " - "--timelapse-input-B).") - - if output_dir is None: - output_dir = get_folder(os.path.join(trainer.model.model_dir, - "timelapse")) - - return Timelapse(input_dir_a, input_dir_b, output_dir, trainer) - - def __init__(self, input_dir_a, input_dir_b, output, trainer): - self.output_dir = output - self.trainer = trainer - - if not os.path.isdir(self.output_dir): - logger.error("'%s' does not exist", self.output_dir) - exit(1) - - self.files_a = self.read_input_images(input_dir_a) - self.files_b = self.read_input_images(input_dir_b) - - btchsz = min(len(self.files_a), len(self.files_b)) - - self.images_a = self.get_image_data(self.files_a, btchsz) - self.images_b = self.get_image_data(self.files_b, btchsz) - - @staticmethod - def read_input_images(input_dir): - """ Get the image paths """ - if not os.path.isdir(input_dir): - logger.error("'%s' does not exist", input_dir) - exit(1) - - if not os.listdir(input_dir): - logger.error("'%s' contains no images", input_dir) - exit(1) - - return get_image_paths(input_dir) - - def get_image_data(self, input_images, batch_size): - """ Get training images """ - random_transform_args = { - 'rotation_range': 0, - 'zoom_range': 0, - 'shift_range': 0, - 'random_flip': 0 - } - - zoom = 1 - if hasattr(self.trainer.model, 'IMAGE_SHAPE'): - zoom = self.trainer.model.IMAGE_SHAPE[0] // 64 - - generator = TrainingDataGenerator(random_transform_args, 160, zoom) - batch = generator.minibatchAB(input_images, batch_size, - doShuffle=False) - - return next(batch)[2] - - def work(self): - """ Write out timelapse image """ - image = self.trainer.show_sample(self.images_a, self.images_b) - cv2.imwrite(os.path.join(self.output_dir, # pylint: disable=no-member - str(int(time())) + ".png"), image) - - def safe_shutdown(): """ Close queues, threads and processes in event of crash """ logger.debug("Safely shutting down") diff --git a/plugins/convert/Convert_Adjust.py b/plugins/convert/Convert_Adjust.py deleted file mode 100644 index 4b4217ff00..0000000000 --- a/plugins/convert/Convert_Adjust.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -""" Adjust converter for faceswap.py - - Based on the original https://www.reddit.com/r/deepfakes/ code sample - Adjust code made by https://github.com/yangchen8710 """ - -import cv2 -import numpy as np - -from lib.utils import add_alpha_channel - - -class Convert(): - """ Adjust Converter """ - def __init__(self, encoder, smooth_mask=True, avg_color_adjust=True, - draw_transparent=False, **kwargs): - self.encoder = encoder - - self.use_smooth_mask = smooth_mask - self.use_avg_color_adjust = avg_color_adjust - self.draw_transparent = draw_transparent - - def patch_image(self, frame, detected_face, size): - """ Patch swapped face onto original image """ - # pylint: disable=no-member - # assert image.shape == (256, 256, 3) - padding = 48 - face_size = 256 - detected_face.load_aligned(frame, face_size, padding, - align_eyes=False) - src_face = detected_face.aligned_face - - crop = slice(padding, face_size - padding) - process_face = src_face[crop, crop] - old_face = process_face.copy() - - process_face = cv2.resize(process_face, - (size, size), - interpolation=cv2.INTER_AREA) - process_face = np.expand_dims(process_face, 0) - - new_face = self.encoder(process_face / 255.0)[0] - new_face = np.clip(new_face * 255, 0, 255).astype(src_face.dtype) - new_face = cv2.resize( - new_face, - (face_size - padding * 2, face_size - padding * 2), - interpolation=cv2.INTER_CUBIC) - - if self.use_avg_color_adjust: - self.adjust_avg_color(old_face, new_face) - if self.use_smooth_mask: - self.smooth_mask(old_face, new_face) - - new_face = self.superpose(src_face, new_face, crop) - new_image = frame.copy() - - if self.draw_transparent: - new_image, new_face = self.convert_transparent(new_image, - new_face) - - cv2.warpAffine( - new_face, - detected_face.adjusted_matrix, - (detected_face.frame_dims[1], detected_face.frame_dims[0]), - new_image, - flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, - borderMode=cv2.BORDER_TRANSPARENT) - return new_image - - @staticmethod - def adjust_avg_color(old_face, new_face): - """ Perform average color adjustment """ - for i in range(new_face.shape[-1]): - old_avg = old_face[:, :, i].mean() - new_avg = new_face[:, :, i].mean() - diff_int = (int)(old_avg - new_avg) - for int_h in range(new_face.shape[0]): - for int_w in range(new_face.shape[1]): - temp = (new_face[int_h, int_w, i] + diff_int) - if temp < 0: - new_face[int_h, int_w, i] = 0 - elif temp > 255: - new_face[int_h, int_w, i] = 255 - else: - new_face[int_h, int_w, i] = temp - - @staticmethod - def smooth_mask(old_face, new_face): - """ Smooth the mask """ - width, height, _ = new_face.shape - crop = slice(0, width) - mask = np.zeros_like(new_face) - mask[height // 15:-height // 15, width // 15:-width // 15, :] = 255 - mask = cv2.GaussianBlur(mask, # pylint: disable=no-member - (15, 15), - 10) - new_face[crop, crop] = (mask / 255 * new_face + - (1 - mask / 255) * old_face) - - @staticmethod - def superpose(src_face, new_face, crop): - """ Crop Face """ - new_image = src_face.copy() - new_image[crop, crop] = new_face - return new_image - - @staticmethod - def convert_transparent(image, new_face): - """ Add alpha channels to images and change to - transparent background """ - height, width = image.shape[:2] - image = np.zeros((height, width, 4), dtype=np.uint8) - new_face = add_alpha_channel(new_face, 100) - return image, new_face diff --git a/plugins/convert/Convert_Masked.py b/plugins/convert/Convert_Masked.py deleted file mode 100644 index a407f5178c..0000000000 --- a/plugins/convert/Convert_Masked.py +++ /dev/null @@ -1,230 +0,0 @@ -#!/usr/bin/env python3 -""" Masked converter for faceswap.py - Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955 - found on https://www.reddit.com/r/deepfakes/ """ - -import logging -import cv2 -import numpy - -from lib.aligner import get_align_mat -from lib.utils import add_alpha_channel - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class Convert(): - def __init__(self, encoder, trainer, - blur_size=2, seamless_clone=False, mask_type="facehullandrect", - erosion_kernel_size=None, match_histogram=False, sharpen_image=None, - draw_transparent=False, **kwargs): - self.encoder = encoder - self.trainer = trainer - self.erosion_kernel = None - self.erosion_kernel_size = erosion_kernel_size - if erosion_kernel_size is not None: - if erosion_kernel_size > 0: - self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (erosion_kernel_size, - erosion_kernel_size)) - elif erosion_kernel_size < 0: - self.erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, - (abs(erosion_kernel_size), - abs(erosion_kernel_size))) - self.blur_size = blur_size - self.seamless_clone = seamless_clone - self.sharpen_image = sharpen_image - self.match_histogram = match_histogram - self.mask_type = mask_type.lower() # Choose in 'FaceHullAndRect', 'FaceHull', 'Rect' - self.draw_transparent = draw_transparent - - def patch_image(self, image, face_detected, size): - - image_size = image.shape[1], image.shape[0] - - mat = numpy.array(get_align_mat(face_detected, - size, - should_align_eyes=False)).reshape(2, 3) - - if "GAN" not in self.trainer: - mat = mat * size - else: - padding = int(48/256*size) - mat = mat * (size - 2 * padding) - mat[:, 2] += padding - - new_face = self.get_new_face(image, mat, size) - - image_mask = self.get_image_mask(image, - new_face, - face_detected.landmarks_as_xy, - mat, - image_size) - - return self.apply_new_face(image, new_face, image_mask, mat, image_size, size) - - @staticmethod - def convert_transparent(image, new_face, image_mask, image_size): - """ Add alpha channels to images and change to - transparent background """ - image = numpy.zeros((image_size[1], image_size[0], 4), - dtype=numpy.uint8) - image_mask = add_alpha_channel(image_mask, 100) - new_face = add_alpha_channel(new_face, 100) - return image, new_face, image_mask - - def apply_new_face(self, image, new_face, image_mask, mat, image_size, size): - - if self.draw_transparent: - image, new_face, image_mask = self.convert_transparent(image, - new_face, - image_mask, - image_size) - self.seamless_clone = False # Alpha channel not supported in seamless - base_image = numpy.copy(image) - new_image = numpy.copy(image) - - cv2.warpAffine(new_face, - mat, - image_size, - new_image, - cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, - cv2.BORDER_TRANSPARENT) - - if self.sharpen_image == "bsharpen": - # Sharpening using filter2D - kernel = numpy.ones((3, 3)) * (-1) - kernel[1, 1] = 9 - new_image = cv2.filter2D(new_image, -1, kernel) - elif self.sharpen_image == "gsharpen": - # Sharpening using Weighted Method - gaussain_blur = cv2.GaussianBlur(new_image, (0, 0), 3.0) - new_image = cv2.addWeighted( - new_image, 1.5, gaussain_blur, -0.5, 0, new_image) - - outimage = None - if self.seamless_clone: - unitMask = numpy.clip(image_mask * 365, 0, 255).astype(numpy.uint8) - logger.info(unitMask.shape) - logger.info(new_image.shape) - logger.info(base_image.shape) - maxregion = numpy.argwhere(unitMask == 255) - - if maxregion.size > 0: - miny, minx = maxregion.min(axis=0)[:2] - maxy, maxx = maxregion.max(axis=0)[:2] - lenx = maxx - minx - leny = maxy - miny - masky = int(minx + (lenx // 2)) - maskx = int(miny + (leny // 2)) - outimage = cv2.seamlessClone(new_image.astype(numpy.uint8), - base_image.astype(numpy.uint8), - unitMask, - (masky, maskx), - cv2.NORMAL_CLONE) - return outimage - - foreground = cv2.multiply(image_mask, new_image.astype(float)) - background = cv2.multiply(1.0 - image_mask, base_image.astype(float)) - outimage = cv2.add(foreground, background) - - return outimage - - def hist_match(self, source, template, mask=None): - # Code borrowed from: - # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x - masked_source = source - masked_template = template - - if mask is not None: - masked_source = source * mask - masked_template = template * mask - - oldshape = source.shape - source = source.ravel() - template = template.ravel() - masked_source = masked_source.ravel() - masked_template = masked_template.ravel() - s_values, bin_idx, s_counts = numpy.unique(source, return_inverse=True, - return_counts=True) - t_values, t_counts = numpy.unique(template, return_counts=True) - ms_values, mbin_idx, ms_counts = numpy.unique(source, return_inverse=True, - return_counts=True) - mt_values, mt_counts = numpy.unique(template, return_counts=True) - - s_quantiles = numpy.cumsum(s_counts).astype(numpy.float64) - s_quantiles /= s_quantiles[-1] - t_quantiles = numpy.cumsum(t_counts).astype(numpy.float64) - t_quantiles /= t_quantiles[-1] - interp_t_values = numpy.interp(s_quantiles, t_quantiles, t_values) - - return interp_t_values[bin_idx].reshape(oldshape) - - def color_hist_match(self, src_im, tar_im, mask): - matched_R = self.hist_match(src_im[:, :, 0], tar_im[:, :, 0], mask) - matched_G = self.hist_match(src_im[:, :, 1], tar_im[:, :, 1], mask) - matched_B = self.hist_match(src_im[:, :, 2], tar_im[:, :, 2], mask) - matched = numpy.stack((matched_R, matched_G, matched_B), axis=2).astype(src_im.dtype) - return matched - - def get_new_face(self, image, mat, size): - face = cv2.warpAffine(image, mat, (size, size)) - face = numpy.expand_dims(face, 0) - face_clipped = numpy.clip(face[0], 0, 255).astype(image.dtype) - new_face = None - mask = None - - if "GAN" not in self.trainer: - normalized_face = face / 255.0 - new_face = self.encoder(normalized_face)[0] - new_face = numpy.clip(new_face * 255, 0, 255).astype(image.dtype) - else: - normalized_face = face / 255.0 * 2 - 1 - fake_output = self.encoder(normalized_face) - if "128" in self.trainer: # TODO: Another hack to switch between 64 and 128 - fake_output = fake_output[0] - mask = fake_output[:, :, :, :1] - new_face = fake_output[:, :, :, 1:] - new_face = mask * new_face + (1 - mask) * normalized_face - new_face = numpy.clip((new_face[0] + 1) * 255 / 2, 0, 255).astype(image.dtype) - - if self.match_histogram: - new_face = self.color_hist_match(new_face, face_clipped, mask) - - return new_face - - def get_image_mask(self, image, new_face, landmarks, mat, image_size): - - face_mask = numpy.zeros(image.shape, dtype=float) - if 'rect' in self.mask_type: - face_src = numpy.ones(new_face.shape, dtype=float) - cv2.warpAffine(face_src, - mat, - image_size, - face_mask, - cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC, cv2.BORDER_TRANSPARENT) - - hull_mask = numpy.zeros(image.shape, dtype=float) - if 'hull' in self.mask_type: - hull = cv2.convexHull( - numpy.array(landmarks).reshape((-1, 2)).astype(int)).flatten().reshape((-1, 2)) - cv2.fillConvexPoly(hull_mask, hull, (1, 1, 1)) - - if self.mask_type == 'rect': - image_mask = face_mask - elif self.mask_type == 'facehull': - image_mask = hull_mask - else: - image_mask = ((face_mask*hull_mask)) - - if self.erosion_kernel is not None: - if self.erosion_kernel_size > 0: - image_mask = cv2.erode(image_mask, self.erosion_kernel, iterations=1) - elif self.erosion_kernel_size < 0: - dilation_kernel = abs(self.erosion_kernel) - image_mask = cv2.dilate(image_mask, dilation_kernel, iterations=1) - - if self.blur_size != 0: - image_mask = cv2.blur(image_mask, (self.blur_size, self.blur_size)) - - return image_mask diff --git a/plugins/convert/masked.py b/plugins/convert/masked.py new file mode 100644 index 0000000000..f3415b5b5d --- /dev/null +++ b/plugins/convert/masked.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +""" Masked converter for faceswap.py + Based on: https://gist.github.com/anonymous/d3815aba83a8f79779451262599b0955 + found on https://www.reddit.com/r/deepfakes/ """ + +import logging +import cv2 +import numpy as np +from lib.model.masks import dfl_full + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class Convert(): + """ Swap a source face with a target """ + def __init__(self, encoder, model, arguments): + logger.debug("Initializing %s: (encoder: '%s', model: %s, arguments: %s", + self.__class__.__name__, encoder, model, arguments) + self.encoder = encoder + self.args = arguments + self.input_size = model.input_shape[0] + self.training_size = model.state.training_size + self.training_coverage_ratio = model.training_opts["coverage_ratio"] + self.input_mask_shape = model.state.mask_shapes[0] if model.state.mask_shapes else None + + self.crop = None + self.mask = None + logger.debug("Initialized %s", self.__class__.__name__) + + def patch_image(self, image, detected_face): + """ Patch the image """ + logger.trace("Patching image") + image = image.astype('float32') + image_size = (image.shape[1], image.shape[0]) + coverage = int(self.training_coverage_ratio * self.training_size) + padding = (self.training_size - coverage) // 2 + logger.trace("coverage: %s, padding: %s", coverage, padding) + + self.crop = slice(padding, self.training_size - padding) + if not self.mask: # Init the mask on first image + self.mask = Mask(self.args.mask_type, self.training_size, padding, self.crop) + + detected_face.load_aligned(image, size=self.training_size, align_eyes=False) + new_image = self.get_new_image(image, detected_face, coverage, image_size) + image_mask = self.get_image_mask(detected_face, image_size) + patched_face = self.apply_fixes(image, + new_image, + image_mask, + image_size) + + logger.trace("Patched image") + return patched_face + + def get_new_image(self, image, detected_face, coverage, image_size): + """ Get the new face from the predictor """ + logger.trace("coverage: %s", coverage) + src_face = detected_face.aligned_face + coverage_face = src_face[self.crop, self.crop] + coverage_face = cv2.resize(coverage_face, # pylint: disable=no-member + (self.input_size, self.input_size), + interpolation=cv2.INTER_AREA) # pylint: disable=no-member + coverage_face = np.expand_dims(coverage_face, 0) + np.clip(coverage_face / 255.0, 0.0, 1.0, out=coverage_face) + + if self.input_mask_shape: + mask = np.zeros(self.input_mask_shape, np.float32) + mask = np.expand_dims(mask, 0) + feed = [coverage_face, mask] + else: + feed = [coverage_face] + logger.trace("Input shapes: %s", [item.shape for item in feed]) + new_face = self.encoder(feed)[0] + new_face = new_face.squeeze() + logger.trace("Output shape: %s", new_face.shape) + + new_face = cv2.resize(new_face, # pylint: disable=no-member + (coverage, coverage), + interpolation=cv2.INTER_CUBIC) # pylint: disable=no-member + np.clip(new_face * 255.0, 0.0, 255.0, out=new_face) + src_face[self.crop, self.crop] = new_face + background = image.copy() + interpolator = detected_face.adjusted_interpolators[1] + new_image = cv2.warpAffine( # pylint: disable=no-member + src_face, + detected_face.adjusted_matrix, + image_size, + background, + flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member + borderMode=cv2.BORDER_TRANSPARENT) # pylint: disable=no-member + return new_image + + def get_image_mask(self, detected_face, image_size): + """ Get the image mask """ + mask = self.mask.get_mask(detected_face, image_size) + if self.args.erosion_size != 0: + kwargs = {'src': mask, + 'kernel': self.set_erosion_kernel(mask), + 'iterations': 1} + if self.args.erosion_size > 0: + mask = cv2.erode(**kwargs) # pylint: disable=no-member + else: + mask = cv2.dilate(**kwargs) # pylint: disable=no-member + + if self.args.blur_size != 0: + blur_size = self.set_blur_size(mask) + mask = cv2.blur(mask, (blur_size, blur_size)) # pylint: disable=no-member + + return np.clip(mask, 0.0, 1.0, out=mask) + + def set_erosion_kernel(self, mask): + """ Set the erosion kernel """ + erosion_ratio = self.args.erosion_size / 100 + mask_radius = np.sqrt(np.sum(mask)) / 2 + percent_erode = max(1, int(abs(erosion_ratio * mask_radius))) + erosion_kernel = cv2.getStructuringElement( # pylint: disable=no-member + cv2.MORPH_ELLIPSE, # pylint: disable=no-member + (percent_erode, percent_erode)) + logger.trace("erosion_kernel shape: %s", erosion_kernel.shape) + return erosion_kernel + + def set_blur_size(self, mask): + """ Set the blur size to absolute or percentage """ + blur_ratio = self.args.blur_size / 100 + mask_radius = np.sqrt(np.sum(mask)) / 2 + blur_size = int(max(1, blur_ratio * mask_radius)) + logger.trace("blur_size: %s", blur_size) + return blur_size + + def apply_fixes(self, frame, new_image, image_mask, image_size): + """ Apply fixes """ + masked = new_image # * image_mask + + if self.args.draw_transparent: + alpha = np.full((image_size[1], image_size[0], 1), 255.0, dtype='float32') + new_image = np.concatenate(new_image, alpha, axis=2) + image_mask = np.concatenate(image_mask, alpha, axis=2) + frame = np.concatenate(frame, alpha, axis=2) + + if self.args.sharpen_image is not None: + np.clip(masked, 0.0, 255.0, out=masked) + if self.args.sharpen_image == "box_filter": + kernel = np.ones((3, 3)) * (-1) + kernel[1, 1] = 9 + masked = cv2.filter2D(masked, -1, kernel) # pylint: disable=no-member + elif self.args.sharpen_image == "gaussian_filter": + blur = cv2.GaussianBlur(masked, (0, 0), 3.0) # pylint: disable=no-member + masked = cv2.addWeighted(masked, # pylint: disable=no-member + 1.5, + blur, + -0.5, + 0, + masked) + + if self.args.avg_color_adjust: + for _ in [0, 1]: + np.clip(masked, 0.0, 255.0, out=masked) + diff = frame - masked + avg_diff = np.sum(diff * image_mask, axis=(0, 1)) + adjustment = avg_diff / np.sum(image_mask, axis=(0, 1)) + masked = masked + adjustment + + if self.args.match_histogram: + np.clip(masked, 0.0, 255.0, out=masked) + masked = self.color_hist_match(masked, frame, image_mask) + + if self.args.seamless_clone and not self.args.draw_transparent: + h, w, _ = frame.shape + h = h // 2 + w = w // 2 + + y_indices, x_indices, _ = np.nonzero(image_mask) + y_crop = slice(np.min(y_indices), np.max(y_indices)) + x_crop = slice(np.min(x_indices), np.max(x_indices)) + y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2) + h) + x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2) + w) + + ''' + # test with average of centroid rather than the h /2 , w/2 center + y_center = int(np.rint(np.average(y_indices) + h) + x_center = int(np.rint(np.average(x_indices) + w) + ''' + + insertion = np.rint(masked[y_crop, x_crop, :]).astype('uint8') + insertion_mask = image_mask[y_crop, x_crop, :] + insertion_mask[insertion_mask != 0] = 255 + insertion_mask = insertion_mask.astype('uint8') + + prior = np.pad(frame, ((h, h), (w, w), (0, 0)), 'constant').astype('uint8') + + blended = cv2.seamlessClone(insertion, # pylint: disable=no-member + prior, + insertion_mask, + (x_center, y_center), + cv2.NORMAL_CLONE) # pylint: disable=no-member + blended = blended[h:-h, w:-w, :] + + else: + foreground = masked * image_mask + background = frame * (1.0 - image_mask) + blended = foreground + background + + np.clip(blended, 0.0, 255.0, out=blended) + + return np.rint(blended).astype('uint8') + + def color_hist_match(self, new, frame, image_mask): + for channel in [0, 1, 2]: + new[:, :, channel] = self.hist_match(new[:, :, channel], + frame[:, :, channel], + image_mask[:, :, channel]) + # source = np.stack([self.hist_match(source[:,:,c], target[:,:,c],image_mask[:,:,c]) + # for c in [0,1,2]], + # axis=2) + return new + + def hist_match(self, new, frame, image_mask): + + mask_indices = np.nonzero(image_mask) + if len(mask_indices[0]) == 0: + return new + + m_new = new[mask_indices].ravel() + m_frame = frame[mask_indices].ravel() + s_values, bin_idx, s_counts = np.unique(m_new, return_inverse=True, return_counts=True) + t_values, t_counts = np.unique(m_frame, return_counts=True) + s_quants = np.cumsum(s_counts, dtype='float32') + t_quants = np.cumsum(t_counts, dtype='float32') + s_quants /= s_quants[-1] # cdf + t_quants /= t_quants[-1] # cdf + interp_s_values = np.interp(s_quants, t_quants, t_values) + new.put(mask_indices, interp_s_values[bin_idx]) + + ''' + bins = np.arange(256) + template_CDF, _ = np.histogram(m_frame, bins=bins, density=True) + flat_new_image = np.interp(m_source.ravel(), bins[:-1], template_CDF) * 255.0 + return flat_new_image.reshape(m_source.shape) * 255.0 + ''' + + return new + + +class Mask(): + """ Return the requested mask """ + + def __init__(self, mask_type, training_size, padding, crop): + """ Set requested mask """ + logger.debug("Initializing %s: (mask_type: '%s', training_size: %s, padding: %s)", + self.__class__.__name__, mask_type, training_size, padding) + + self.training_size = training_size + self.padding = padding + self.mask_type = mask_type + self.crop = crop + + logger.debug("Initialized %s", self.__class__.__name__) + + def get_mask(self, detected_face, image_size): + """ Return a face mask """ + kwargs = {"matrix": detected_face.adjusted_matrix, + "interpolators": detected_face.adjusted_interpolators, + "landmarks": detected_face.landmarks_as_xy, + "image_size": image_size} + logger.trace("kwargs: %s", kwargs) + mask = getattr(self, self.mask_type)(**kwargs) + mask = self.finalize_mask(mask) + logger.trace("mask shape: %s", mask.shape) + return mask + + def cnn(self, **kwargs): + """ CNN Mask """ + # Insert FCN-VGG16 segmentation mask model here + logger.info("cnn not yet implemented, using facehull instead") + return self.facehull(**kwargs) + + def smoothed(self, **kwargs): + """ Smoothed Mask """ + logger.trace("Getting mask") + interpolator = kwargs["interpolators"][1] + ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32') + # area = self.padding + (self.training_size - 2 * self.padding) // 15 + # central_core = slice(area, -area) + ones[self.crop, self.crop] = 1.0 + ones = cv2.GaussianBlur(ones, (25, 25), 10) # pylint: disable=no-member + + mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') + cv2.warpAffine(ones, # pylint: disable=no-member + kwargs["matrix"], + kwargs["image_size"], + mask, + flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member + borderMode=cv2.BORDER_CONSTANT, # pylint: disable=no-member + borderValue=0.0) + return mask + + def rect(self, **kwargs): + """ Rect Mask """ + logger.trace("Getting mask") + interpolator = kwargs["interpolators"][1] + ones = np.zeros((self.training_size, self.training_size, 3), dtype='float32') + mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') + # central_core = slice(self.padding, -self.padding) + ones[self.crop, self.crop] = 1.0 + cv2.warpAffine(ones, # pylint: disable=no-member + kwargs["matrix"], + kwargs["image_size"], + mask, + flags=cv2.WARP_INVERSE_MAP | interpolator, # pylint: disable=no-member + borderMode=cv2.BORDER_CONSTANT, # pylint: disable=no-member + borderValue=0.0) + return mask + + def dfl(self, **kwargs): + """ DFaker Mask """ + logger.trace("Getting mask") + dummy = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') + mask = dfl_full(kwargs["landmarks"], dummy, channels=3) + return mask + + def facehull(self, **kwargs): + """ Facehull Mask """ + logger.trace("Getting mask") + mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') + hull = cv2.convexHull( # pylint: disable=no-member + np.array(kwargs["landmarks"]).reshape((-1, 2))) + cv2.fillConvexPoly(mask, # pylint: disable=no-member + hull, + (1.0, 1.0, 1.0), + lineType=cv2.LINE_AA) # pylint: disable=no-member + return mask + + def facehull_rect(self, **kwargs): + """ Facehull Rect Mask """ + logger.trace("Getting mask") + mask = self.rect(**kwargs) + hull_mask = self.facehull(**kwargs) + mask *= hull_mask + return mask + + def ellipse(self, **kwargs): + """ Ellipse Mask """ + logger.trace("Getting mask") + mask = np.zeros((kwargs["image_size"][1], kwargs["image_size"][0], 3), dtype='float32') + ell = cv2.fitEllipse( # pylint: disable=no-member + np.array(kwargs["landmarks"]).reshape((-1, 2))) + cv2.ellipse(mask, # pylint: disable=no-member + box=ell, + color=(1.0, 1.0, 1.0), + thickness=-1) + return mask + + @staticmethod + def finalize_mask(mask): + """ Finalize the mask """ + logger.trace("Finalizing mask") + np.nan_to_num(mask, copy=False) + np.clip(mask, 0.0, 1.0, out=mask) + return mask diff --git a/plugins/extract/_config.py b/plugins/extract/_config.py new file mode 100644 index 0000000000..092b0665d0 --- /dev/null +++ b/plugins/extract/_config.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +""" Default configurations for extract """ + +import logging + +from lib.config import FaceswapConfig + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class Config(FaceswapConfig): + """ Config File for Models """ + + def set_defaults(self): + """ Set the default values for config """ + logger.debug("Setting defaults") + + # << GLOBAL OPTIONS >> # +# section = "global" +# self.add_section(title=section, +# info="Options that apply to all models") + + # << MTCNN DETECTOR OPTIONS >> # + section = "detect.mtcnn" + self.add_section(title=section, + info="MTCNN Detector options") + self.add_item( + section=section, title="minsize", datatype=int, default=20, rounding=10, + min_max=(20, 1000), + info="The minimum size of a face (in pixels) to be accepted as a positive match.\n" + "Lower values use significantly more VRAM and will detect more false positives") + self.add_item( + section=section, title="threshold_1", datatype=float, default=0.6, rounding=2, + min_max=(0.1, 0.9), + info="First stage threshold for face detection. This stage obtains face candidates") + self.add_item( + section=section, title="threshold_2", datatype=float, default=0.7, rounding=2, + min_max=(0.1, 0.9), + info="Second stage threshold for face detection. This stage refines face candidates") + self.add_item( + section=section, title="threshold_3", datatype=float, default=0.7, rounding=2, + min_max=(0.1, 0.9), + info="Third stage threshold for face detection. This stage further refines face " + "candidates") + self.add_item( + section=section, title="scalefactor", datatype=float, default=0.709, rounding=3, + min_max=(0.1, 0.9), + info="The scale factor for the image pyramid") diff --git a/plugins/extract/detect/_base.py b/plugins/extract/detect/_base.py index 88c1be5a5c..7ffa788104 100644 --- a/plugins/extract/detect/_base.py +++ b/plugins/extract/detect/_base.py @@ -22,14 +22,21 @@ from lib.gpu_stats import GPUStats from lib.utils import rotate_landmarks +from plugins.extract._config import Config logger = logging.getLogger(__name__) # pylint: disable=invalid-name +def get_config(plugin_name): + """ Return the config for the requested model """ + return Config(plugin_name).config_dict + + class Detector(): """ Detector object """ def __init__(self, loglevel, rotation=None): logger.debug("Initializing %s: (rotation: %s)", self.__class__.__name__, rotation) + self.config = get_config(".".join(self.__module__.split(".")[-2:])) self.loglevel = loglevel self.cachepath = os.path.join(os.path.dirname(__file__), ".cache") self.rotation = self.get_rotation_angles(rotation) @@ -107,6 +114,7 @@ def run(self, *args, **kwargs): logger.exception("Traceback:") tb_buffer = StringIO() traceback.print_exc(file=tb_buffer) + logger.trace(tb_buffer.getvalue()) exception = {"exception": (os.getpid(), tb_buffer)} self.queues["out"].put(exception) exit(1) diff --git a/plugins/extract/detect/dlib_cnn.py b/plugins/extract/detect/dlib_cnn.py index 69bb5c3d61..c506b646c5 100644 --- a/plugins/extract/detect/dlib_cnn.py +++ b/plugins/extract/detect/dlib_cnn.py @@ -111,7 +111,7 @@ def compile_detection_images(self, images): def detect_batch(self, detect_images, disable_message=False): """ Pass the batch through detector for consistently sized images - or each image seperately for inconsitently sized images """ + or each image separately for inconsitently sized images """ logger.trace("Detecting Batch") can_batch = self.check_batch_dims(detect_images) if can_batch: diff --git a/plugins/extract/detect/mtcnn.py b/plugins/extract/detect/mtcnn.py index bf28a641f0..d6aa42c38d 100644 --- a/plugins/extract/detect/mtcnn.py +++ b/plugins/extract/detect/mtcnn.py @@ -30,21 +30,24 @@ class Detect(Detector): """ MTCNN detector for face recognition """ def __init__(self, **kwargs): super().__init__(**kwargs) - self.kwargs = None + self.kwargs = self.validate_kwargs() self.name = "mtcnn" self.target = 2073600 # Uses approx 1.30 GB of VRAM self.vram = 1408 - @staticmethod - def validate_kwargs(kwargs): - """ Validate that cli kwargs are correct. If not reset to default """ + def validate_kwargs(self): + """ Validate that config options are correct. If not reset to default """ valid = True - if kwargs['minsize'] < 10: - valid = False - elif len(kwargs['threshold']) != 3: + threshold = [self.config["threshold_1"], + self.config["threshold_2"], + self.config["threshold_3"]] + kwargs = {"minsize": self.config["minsize"], + "threshold": threshold, + "factor": self.config["scalefactor"]} + + if kwargs["minsize"] < 10: valid = False - elif not all(0.0 < threshold < 1.0 - for threshold in kwargs['threshold']): + elif not all(0.0 < threshold <= 1.0 for threshold in kwargs['threshold']): valid = False elif not 0.0 < kwargs['factor'] < 1.0: valid = False @@ -53,7 +56,7 @@ def validate_kwargs(kwargs): kwargs = {"minsize": 20, # minimum size of face "threshold": [0.6, 0.7, 0.7], # three steps threshold "factor": 0.709} # scale factor - logger.warning("Invalid MTCNN arguments received. Running with defaults") + logger.warning("Invalid MTCNN options in config. Running with defaults") logger.debug("Using mtcnn kwargs: %s", kwargs) return kwargs @@ -72,7 +75,6 @@ def initialize(self, *args, **kwargs): super().initialize(*args, **kwargs) logger.info("Initializing MTCNN Detector...") is_gpu = False - self.kwargs = kwargs["mtcnn_kwargs"] # Must import tensorflow inside the spawned process # for Windows machines diff --git a/plugins/model/Model_GAN/Model.py b/plugins/model/Model_GAN/Model.py deleted file mode 100644 index 820652c187..0000000000 --- a/plugins/model/Model_GAN/Model.py +++ /dev/null @@ -1,187 +0,0 @@ -# Based on the https://github.com/shaoanlu/faceswap-GAN repo (master/temp/faceswap_GAN_keras.ipynb) - -import logging - -from keras.models import Model -from keras.layers import * -from keras.layers.advanced_activations import LeakyReLU -from keras.activations import relu -from keras.initializers import RandomNormal -from keras.applications import * -from keras.optimizers import Adam - -from lib.PixelShuffler import PixelShuffler -from .instance_normalization import InstanceNormalization -from lib.utils import backup_file - -from keras.utils import multi_gpu_model - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -hdf = {'netGAH5': 'netGA_GAN.h5', - 'netGBH5': 'netGB_GAN.h5', - 'netDAH5': 'netDA_GAN.h5', - 'netDBH5': 'netDB_GAN.h5'} - -def __conv_init(a): - logger.info("conv_init %s", a) - k = RandomNormal(0, 0.02)(a) # for convolution kernel - k.conv_weight = True - return k - -#def batchnorm(): -# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init) - -def inst_norm(): - return InstanceNormalization() - -conv_init = RandomNormal(0, 0.02) -gamma_init = RandomNormal(1., 0.02) # for batch normalization - -class GANModel(): - img_size = 64 - channels = 3 - img_shape = (img_size, img_size, channels) - encoded_dim = 1024 - nc_in = 3 # number of input channels of generators - nc_D_inp = 6 # number of input channels of discriminators - - def __init__(self, model_dir, gpus): - self.model_dir = model_dir - self.gpus = gpus - - optimizer = Adam(1e-4, 0.5) - - # Build and compile the discriminator - self.netDA, self.netDB = self.build_discriminator() - - # Build and compile the generator - self.netGA, self.netGB = self.build_generator() - - def converter(self, swap): - predictor = self.netGB if not swap else self.netGA - return lambda img: predictor.predict(img) - - def build_generator(self): - - def conv_block(input_tensor, f): - x = input_tensor - x = Conv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - x = Activation("relu")(x) - return x - - def res_block(input_tensor, f): - x = input_tensor - x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - x = LeakyReLU(alpha=0.2)(x) - x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - x = add([x, input_tensor]) - x = LeakyReLU(alpha=0.2)(x) - return x - - def upscale_ps(filters, use_instance_norm=True): - def block(x): - x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder(nc_in=3, input_size=64): - inp = Input(shape=(input_size, input_size, nc_in)) - x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp) - x = conv_block(x,128) - x = conv_block(x,256) - x = conv_block(x,512) - x = conv_block(x,1024) - x = Dense(1024)(Flatten()(x)) - x = Dense(4*4*1024)(x) - x = Reshape((4, 4, 1024))(x) - out = upscale_ps(512)(x) - return Model(inputs=inp, outputs=out) - - def Decoder_ps(nc_in=512, input_size=8): - input_ = Input(shape=(input_size, input_size, nc_in)) - x = input_ - x = upscale_ps(256)(x) - x = upscale_ps(128)(x) - x = upscale_ps(64)(x) - x = res_block(x, 64) - x = res_block(x, 64) - #x = Conv2D(4, kernel_size=5, padding='same')(x) - alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x) - rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x) - out = concatenate([alpha, rgb]) - return Model(input_, out ) - - encoder = Encoder() - decoder_A = Decoder_ps() - decoder_B = Decoder_ps() - x = Input(shape=self.img_shape) - netGA = Model(x, decoder_A(encoder(x))) - netGB = Model(x, decoder_B(encoder(x))) - - self.netGA_sm = netGA - self.netGB_sm = netGB - - try: - netGA.load_weights(str(self.model_dir / hdf['netGAH5'])) - netGB.load_weights(str(self.model_dir / hdf['netGBH5'])) - logger.info("Generator models loaded.") - except: - logger.info("Generator weights files not found.") - pass - - if self.gpus > 1: - netGA = multi_gpu_model( self.netGA_sm , self.gpus) - netGB = multi_gpu_model( self.netGB_sm , self.gpus) - - return netGA, netGB - - def build_discriminator(self): - def conv_block_d(input_tensor, f, use_instance_norm=True): - x = input_tensor - x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - x = LeakyReLU(alpha=0.2)(x) - return x - - def Discriminator(nc_in, input_size=64): - inp = Input(shape=(input_size, input_size, nc_in)) - #x = GaussianNoise(0.05)(inp) - x = conv_block_d(inp, 64, False) - x = conv_block_d(x, 128, False) - x = conv_block_d(x, 256, False) - out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x) - return Model(inputs=[inp], outputs=out) - - netDA = Discriminator(self.nc_D_inp) - netDB = Discriminator(self.nc_D_inp) - try: - netDA.load_weights(str(self.model_dir / hdf['netDAH5'])) - netDB.load_weights(str(self.model_dir / hdf['netDBH5'])) - logger.info("Discriminator models loaded.") - except: - logger.info("Discriminator weights files not found.") - pass - return netDA, netDB - - def load(self, swapped): - if swapped: - logger.warning("swapping not supported on GAN") - # TODO load is done in __init__ => look how to swap if possible - return True - - def save_weights(self): - model_dir = str(self.model_dir) - for model in hdf.values(): - backup_file(model_dir, model) - if self.gpus > 1: - self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5'])) - self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5'])) - else: - self.netGA.save_weights(str(self.model_dir / hdf['netGAH5'])) - self.netGB.save_weights(str(self.model_dir / hdf['netGBH5'])) - self.netDA.save_weights(str(self.model_dir / hdf['netDAH5'])) - self.netDB.save_weights(str(self.model_dir / hdf['netDBH5'])) - logger.info("Models saved.") diff --git a/plugins/model/Model_GAN/Trainer.py b/plugins/model/Model_GAN/Trainer.py deleted file mode 100644 index 28753b4ec0..0000000000 --- a/plugins/model/Model_GAN/Trainer.py +++ /dev/null @@ -1,260 +0,0 @@ -import time -import cv2 -import numpy as np - -from keras.layers import * -from tensorflow.contrib.distributions import Beta -import tensorflow as tf -from keras.optimizers import Adam -from keras import backend as K - -from lib.training_data import TrainingDataGenerator, stack_images - -class GANTrainingDataGenerator(TrainingDataGenerator): - def __init__(self, random_transform_args, coverage, scale, zoom): - super().__init__(random_transform_args, coverage, scale, zoom) - - def color_adjust(self, img): - return img / 255.0 * 2 - 1 - -class Trainer(): - random_transform_args = { - 'rotation_range': 20, - 'zoom_range': 0.1, - 'shift_range': 0.05, - 'random_flip': 0.5, - } - - def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss): - K.set_learning_phase(1) - - assert batch_size % 2 == 0, "batch_size must be an even number" - self.batch_size = batch_size - self.model = model - - self.use_lsgan = True - self.use_mixup = True - self.mixup_alpha = 0.2 - self.use_perceptual_loss = perceptual_loss - self.use_instancenorm = False - - self.lrD = 1e-4 # Discriminator learning rate - self.lrG = 1e-4 # Generator learning rate - - generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 1) - self.train_batchA = generator.minibatchAB(fn_A, batch_size) - self.train_batchB = generator.minibatchAB(fn_B, batch_size) - - self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0 - - self.setup() - - def setup(self): - distorted_A, fake_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA) - distorted_B, fake_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB) - real_A = Input(shape=self.model.img_shape) - real_B = Input(shape=self.model.img_shape) - - if self.use_lsgan: - self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target))) - else: - self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target)) - - # ========== Define Perceptual Loss Model========== - if self.use_perceptual_loss: - from keras.models import Model - from keras_vggface.vggface import VGGFace - vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3)) - vggface.trainable = False - out_size55 = vggface.layers[36].output - out_size28 = vggface.layers[78].output - out_size7 = vggface.layers[-2].output - vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7]) - vggface_feat.trainable = False - else: - vggface_feat = None - - #TODO check "Tips for mask refinement (optional after >15k iters)" => https://render.githubusercontent.com/view/ipynb?commit=87d6e7a28ce754acd38d885367b6ceb0be92ec54&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f7368616f616e6c752f66616365737761702d47414e2f383764366537613238636537353461636433386438383533363762366365623062653932656335342f46616365537761705f47414e5f76325f737a3132385f747261696e2e6970796e62&nwo=shaoanlu%2Ffaceswap-GAN&path=FaceSwap_GAN_v2_sz128_train.ipynb&repository_id=115182783&repository_type=Repository#Tips-for-mask-refinement-(optional-after-%3E15k-iters) - loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, distorted_A, vggface_feat) - loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, distorted_B, vggface_feat) - - loss_GA += 1e-3 * K.mean(K.abs(mask_A)) - loss_GB += 1e-3 * K.mean(K.abs(mask_B)) - - w_fo = 0.01 - loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1)) - loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2)) - loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1)) - loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2)) - - weightsDA = self.model.netDA.trainable_weights - weightsGA = self.model.netGA.trainable_weights - weightsDB = self.model.netDB.trainable_weights - weightsGB = self.model.netGB.trainable_weights - - # Adam(..).get_updates(...) - training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA) - self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates) - training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA) - self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates) - - training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB) - self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates) - training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB) - self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates) - - def first_order(self, x, axis=1): - img_nrows = x.shape[1] - img_ncols = x.shape[2] - if axis == 1: - return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :]) - elif axis == 2: - return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :]) - else: - return None - - def train_one_step(self, iter, viewer): - # --------------------- - # Train Discriminators - # --------------------- - - # Select a random half batch of images - epoch, warped_A, target_A = next(self.train_batchA) - epoch, warped_B, target_B = next(self.train_batchB) - - # Train dicriminators for one batch - errDA = self.netDA_train([warped_A, target_A]) - errDB = self.netDB_train([warped_B, target_B]) - - # Train generators for one batch - errGA = self.netGA_train([warped_A, target_A]) - errGB = self.netGB_train([warped_B, target_B]) - - # For calculating average losses - self.errDA_sum += errDA[0] - self.errDB_sum += errDB[0] - self.errGA_sum += errGA[0] - self.errGB_sum += errGB[0] - self.avg_counter += 1 - - print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f' - % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter), - end='\r') - - if viewer is not None: - self.show_sample(viewer) - - def cycle_variables(self, netG): - distorted_input = netG.inputs[0] - fake_output = netG.outputs[0] - alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output) - rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output) - - masked_fake_output = alpha * rgb + (1-alpha) * distorted_input - - fn_generate = K.function([distorted_input], [masked_fake_output]) - fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])]) - fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])]) - fn_bgr = K.function([distorted_input], [rgb]) - return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr - - def define_loss(self, netD, real, fake_argb, distorted, vggface_feat=None): - alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb) - fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb) - fake = alpha * fake_rgb + (1-alpha) * distorted - - if self.use_mixup: - dist = Beta(self.mixup_alpha, self.mixup_alpha) - lam = dist.sample() - # ========== - mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted]) - # ========== - output_mixup = netD(mixup) - loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup)) - output_fake = netD(concatenate([fake, distorted])) # dummy - loss_G = .5 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup)) - else: - output_real = netD(concatenate([real, distorted])) # positive sample - output_fake = netD(concatenate([fake, distorted])) # negative sample - loss_D_real = self.loss_fn(output_real, K.ones_like(output_real)) - loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake)) - loss_D = loss_D_real + loss_D_fake - loss_G = .5 * self.loss_fn(output_fake, K.ones_like(output_fake)) - # ========== - loss_G += K.mean(K.abs(fake_rgb - real)) - # ========== - - # Edge loss (similar with total variation loss) - loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=1) - self.first_order(real, axis=1))) - loss_G += 1 * K.mean(K.abs(self.first_order(fake_rgb, axis=2) - self.first_order(real, axis=2))) - - - # Perceptual Loss - if not vggface_feat is None: - def preprocess_vggface(x): - x = (x + 1)/2 * 255 # channel order: BGR - #x[..., 0] -= 93.5940 - #x[..., 1] -= 104.7624 - #x[..., 2] -= 129. - x -= [91.4953, 103.8827, 131.0912] - return x - pl_params = (0.011, 0.11, 0.1919) - real_sz224 = tf.image.resize_images(real, [224, 224]) - real_sz224 = Lambda(preprocess_vggface)(real_sz224) - # ========== - fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224]) - fake_sz224 = Lambda(preprocess_vggface)(fake_sz224) - # ========== - real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224) - fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224) - loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7)) - loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28)) - loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55)) - - return loss_D, loss_G - - def show_sample(self, display_fn): - _, wA, tA = next(self.train_batchA) - _, wB, tB = next(self.train_batchB) - display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked") - display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw") - display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask") - # Reset the averages - self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0 - self.avg_counter = 0 - - def showG(self, test_A, test_B, path_A, path_B): - figure_A = np.stack([ - test_A, - np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])), - np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])), - ], axis=1 ) - figure_B = np.stack([ - test_B, - np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])), - np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])), - ], axis=1 ) - - figure = np.concatenate([figure_A, figure_B], axis=0 ) - figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:]) - figure = stack_images(figure) - figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8') - return figure - - def showG_mask(self, test_A, test_B, path_A, path_B): - figure_A = np.stack([ - test_A, - (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1, - (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1, - ], axis=1 ) - figure_B = np.stack([ - test_B, - (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1, - (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1, - ], axis=1 ) - - figure = np.concatenate([figure_A, figure_B], axis=0 ) - figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:]) - figure = stack_images(figure) - figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8') - return figure diff --git a/plugins/model/Model_GAN/__init__.py b/plugins/model/Model_GAN/__init__.py deleted file mode 100644 index 0722bbad07..0000000000 --- a/plugins/model/Model_GAN/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """Based on https://github.com/shaoanlu/""" -__version__ = '0.1.0' - -from .Model import GANModel as Model -from .Trainer import Trainer \ No newline at end of file diff --git a/plugins/model/Model_GAN/instance_normalization.py b/plugins/model/Model_GAN/instance_normalization.py deleted file mode 100644 index 9c0616af27..0000000000 --- a/plugins/model/Model_GAN/instance_normalization.py +++ /dev/null @@ -1,145 +0,0 @@ -from keras.engine import Layer, InputSpec -from keras import initializers, regularizers, constraints -from keras import backend as K -from keras.utils.generic_utils import get_custom_objects - -import numpy as np - - -class InstanceNormalization(Layer): - """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). - Normalize the activations of the previous layer at each step, - i.e. applies a transformation that maintains the mean activation - close to 0 and the activation standard deviation close to 1. - # Arguments - axis: Integer, the axis that should be normalized - (typically the features axis). - For instance, after a `Conv2D` layer with - `data_format="channels_first"`, - set `axis=1` in `InstanceNormalization`. - Setting `axis=None` will normalize all values in each instance of the batch. - Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. - scale: If True, multiply by `gamma`. - If False, `gamma` is not used. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: Optional constraint for the beta weight. - gamma_constraint: Optional constraint for the gamma weight. - # Input shape - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - # Output shape - Same shape as input. - # References - - [Layer Normalization](https://arxiv.org/abs/1607.06450) - - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022) - """ - def __init__(self, - axis=None, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - **kwargs): - super(InstanceNormalization, self).__init__(**kwargs) - self.supports_masking = True - self.axis = axis - self.epsilon = epsilon - self.center = center - self.scale = scale - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - self.beta_regularizer = regularizers.get(beta_regularizer) - self.gamma_regularizer = regularizers.get(gamma_regularizer) - self.beta_constraint = constraints.get(beta_constraint) - self.gamma_constraint = constraints.get(gamma_constraint) - - def build(self, input_shape): - ndim = len(input_shape) - if self.axis == 0: - raise ValueError('Axis cannot be zero') - - if (self.axis is not None) and (ndim == 2): - raise ValueError('Cannot specify axis for rank 1 tensor') - - self.input_spec = InputSpec(ndim=ndim) - - if self.axis is None: - shape = (1,) - else: - shape = (input_shape[self.axis],) - - if self.scale: - self.gamma = self.add_weight(shape=shape, - name='gamma', - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - constraint=self.gamma_constraint) - else: - self.gamma = None - if self.center: - self.beta = self.add_weight(shape=shape, - name='beta', - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - constraint=self.beta_constraint) - else: - self.beta = None - self.built = True - - def call(self, inputs, training=None): - input_shape = K.int_shape(inputs) - reduction_axes = list(range(0, len(input_shape))) - - if (self.axis is not None): - del reduction_axes[self.axis] - - del reduction_axes[0] - - mean = K.mean(inputs, reduction_axes, keepdims=True) - stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon - normed = (inputs - mean) / stddev - - broadcast_shape = [1] * len(input_shape) - if self.axis is not None: - broadcast_shape[self.axis] = input_shape[self.axis] - - if self.scale: - broadcast_gamma = K.reshape(self.gamma, broadcast_shape) - normed = normed * broadcast_gamma - if self.center: - broadcast_beta = K.reshape(self.beta, broadcast_shape) - normed = normed + broadcast_beta - return normed - - def get_config(self): - config = { - 'axis': self.axis, - 'epsilon': self.epsilon, - 'center': self.center, - 'scale': self.scale, - 'beta_initializer': initializers.serialize(self.beta_initializer), - 'gamma_initializer': initializers.serialize(self.gamma_initializer), - 'beta_regularizer': regularizers.serialize(self.beta_regularizer), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_constraint': constraints.serialize(self.beta_constraint), - 'gamma_constraint': constraints.serialize(self.gamma_constraint) - } - base_config = super(InstanceNormalization, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - -get_custom_objects().update({'InstanceNormalization': InstanceNormalization}) diff --git a/plugins/model/Model_GAN128/Model.py b/plugins/model/Model_GAN128/Model.py deleted file mode 100644 index 57236360c5..0000000000 --- a/plugins/model/Model_GAN128/Model.py +++ /dev/null @@ -1,204 +0,0 @@ -# Based on the https://github.com/shaoanlu/faceswap-GAN repo -# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb -import logging - -from keras.models import Model -from keras.layers import * -from keras.layers.advanced_activations import LeakyReLU -from keras.activations import relu -from keras.initializers import RandomNormal -from keras.applications import * -from keras.optimizers import Adam - -from lib.PixelShuffler import PixelShuffler -from .instance_normalization import InstanceNormalization -from lib.utils import backup_file - -from keras.utils import multi_gpu_model - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -hdf = {'netGAH5':'netGA_GAN128.h5', - 'netGBH5': 'netGB_GAN128.h5', - 'netDAH5': 'netDA_GAN128.h5', - 'netDBH5': 'netDB_GAN128.h5'} - -def __conv_init(a): - logger.info("conv_init %s", a) - k = RandomNormal(0, 0.02)(a) # for convolution kernel - k.conv_weight = True - return k - -#def batchnorm(): -# return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init) - -def inst_norm(): - return InstanceNormalization() - -conv_init = RandomNormal(0, 0.02) -gamma_init = RandomNormal(1., 0.02) # for batch normalization - -class GANModel(): - img_size = 128 - channels = 3 - img_shape = (img_size, img_size, channels) - encoded_dim = 1024 - nc_in = 3 # number of input channels of generators - nc_D_inp = 6 # number of input channels of discriminators - - def __init__(self, model_dir, gpus): - self.model_dir = model_dir - self.gpus = gpus - - optimizer = Adam(1e-4, 0.5) - - # Build and compile the discriminator - self.netDA, self.netDB = self.build_discriminator() - - # Build and compile the generator - self.netGA, self.netGB = self.build_generator() - - def converter(self, swap): - predictor = self.netGB if not swap else self.netGA - return lambda img: predictor.predict(img) - - def build_generator(self): - - def conv_block(input_tensor, f, use_instance_norm=True): - x = input_tensor - x = SeparableConv2D(f, kernel_size=3, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - if use_instance_norm: - x = inst_norm()(x) - x = Activation("relu")(x) - return x - - def res_block(input_tensor, f, dilation=1): - x = input_tensor - x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x) - x = LeakyReLU(alpha=0.2)(x) - x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init, use_bias=False, padding="same", dilation_rate=dilation)(x) - x = add([x, input_tensor]) - #x = LeakyReLU(alpha=0.2)(x) - return x - - def upscale_ps(filters, use_instance_norm=True): - def block(x, use_instance_norm=use_instance_norm): - x = Conv2D(filters*4, kernel_size=3, use_bias=False, kernel_initializer=RandomNormal(0, 0.02), padding='same')(x) - if use_instance_norm: - x = inst_norm()(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder(nc_in=3, input_size=128): - inp = Input(shape=(input_size, input_size, nc_in)) - x = Conv2D(32, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp) - x = conv_block(x,64, use_instance_norm=False) - x = conv_block(x,128) - x = conv_block(x,256) - x = conv_block(x,512) - x = conv_block(x,1024) - x = Dense(1024)(Flatten()(x)) - x = Dense(4*4*1024)(x) - x = Reshape((4, 4, 1024))(x) - out = upscale_ps(512)(x) - return Model(inputs=inp, outputs=out) - - def Decoder_ps(nc_in=512, input_size=8): - input_ = Input(shape=(input_size, input_size, nc_in)) - x = input_ - x = upscale_ps(256)(x) - x = upscale_ps(128)(x) - x = upscale_ps(64)(x) - x = res_block(x, 64, dilation=2) - - out64 = Conv2D(64, kernel_size=3, padding='same')(x) - out64 = LeakyReLU(alpha=0.1)(out64) - out64 = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(out64) - - x = upscale_ps(32)(x) - x = res_block(x, 32) - x = res_block(x, 32) - alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x) - rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x) - out = concatenate([alpha, rgb]) - return Model(input_, [out, out64] ) - - encoder = Encoder() - decoder_A = Decoder_ps() - decoder_B = Decoder_ps() - x = Input(shape=self.img_shape) - netGA = Model(x, decoder_A(encoder(x))) - netGB = Model(x, decoder_B(encoder(x))) - netGA.output_names = ["netGA_out_1", "netGA_out_2"] # Workarounds till https://github.com/keras-team/keras/issues/8962 is fixed. - netGB.output_names = ["netGB_out_1", "netGB_out_2"] # - - self.netGA_sm = netGA - self.netGB_sm = netGB - - try: - netGA.load_weights(str(self.model_dir / hdf['netGAH5'])) - netGB.load_weights(str(self.model_dir / hdf['netGBH5'])) - logger.info("Generator models loaded.") - except: - logger.info("Generator weights files not found.") - pass - - if self.gpus > 1: - netGA = multi_gpu_model( self.netGA_sm , self.gpus) - netGB = multi_gpu_model( self.netGB_sm , self.gpus) - - return netGA, netGB - - def build_discriminator(self): - def conv_block_d(input_tensor, f, use_instance_norm=True): - x = input_tensor - x = Conv2D(f, kernel_size=4, strides=2, kernel_initializer=conv_init, use_bias=False, padding="same")(x) - if use_instance_norm: - x = inst_norm()(x) - x = LeakyReLU(alpha=0.2)(x) - return x - - def Discriminator(nc_in, input_size=128): - inp = Input(shape=(input_size, input_size, nc_in)) - #x = GaussianNoise(0.05)(inp) - x = conv_block_d(inp, 64, False) - x = conv_block_d(x, 128, True) - x = conv_block_d(x, 256, True) - x = conv_block_d(x, 512, True) - out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same", activation="sigmoid")(x) - return Model(inputs=[inp], outputs=out) - - netDA = Discriminator(self.nc_D_inp) - netDB = Discriminator(self.nc_D_inp) - - try: - netDA.load_weights(str(self.model_dir / hdf['netDAH5'])) - netDB.load_weights(str(self.model_dir / hdf['netDBH5'])) - logger.info("Discriminator models loaded.") - except: - logger.info("Discriminator weights files not found.") - pass - return netDA, netDB - - def load(self, swapped): - if swapped: - logger.warning("swapping not supported on GAN") - # TODO load is done in __init__ => look how to swap if possible - return True - - def save_weights(self): - model_dir = str(self.model_dir) - for model in hdf.values(): - backup_file(model_dir, model) - if self.gpus > 1: - self.netGA_sm.save_weights(str(self.model_dir / hdf['netGAH5'])) - self.netGB_sm.save_weights(str(self.model_dir / hdf['netGBH5'])) - else: - self.netGA.save_weights(str(self.model_dir / hdf['netGAH5'])) - self.netGB.save_weights(str(self.model_dir / hdf['netGBH5'])) - self.netDA.save_weights(str(self.model_dir / hdf['netDAH5'])) - self.netDB.save_weights(str(self.model_dir / hdf['netDBH5'])) - logger.info("Models saved.") diff --git a/plugins/model/Model_GAN128/Trainer.py b/plugins/model/Model_GAN128/Trainer.py deleted file mode 100644 index 8fc1b36aaf..0000000000 --- a/plugins/model/Model_GAN128/Trainer.py +++ /dev/null @@ -1,263 +0,0 @@ -import time -import cv2 -import numpy as np - -from keras.layers import * -from tensorflow.contrib.distributions import Beta -import tensorflow as tf -from keras.optimizers import Adam -from keras import backend as K - -from lib.training_data import TrainingDataGenerator, stack_images - -class GANTrainingDataGenerator(TrainingDataGenerator): - def __init__(self, random_transform_args, coverage, scale, zoom): - super().__init__(random_transform_args, coverage, scale, zoom) - - def color_adjust(self, img): - return img / 255.0 * 2 - 1 - -class Trainer(): - random_transform_args = { - 'rotation_range': 20, - 'zoom_range': 0.1, - 'shift_range': 0.05, - 'random_flip': 0.5, - } - - def __init__(self, model, fn_A, fn_B, batch_size, perceptual_loss): - K.set_learning_phase(1) - - assert batch_size % 2 == 0, "batch_size must be an even number" - self.batch_size = batch_size - self.model = model - - self.use_lsgan = True - self.use_mixup = True - self.mixup_alpha = 0.2 - self.use_perceptual_loss = perceptual_loss - self.use_mask_refinement = False #OPTIONAL After 15k iteration** - - self.lrD = 1e-4 # Discriminator learning rate - self.lrG = 1e-4 # Generator learning rate - - generator = GANTrainingDataGenerator(self.random_transform_args, 220, 6, 2) - self.train_batchA = generator.minibatchAB(fn_A, batch_size) - self.train_batchB = generator.minibatchAB(fn_B, batch_size) - - self.avg_counter = self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0 - - self.setup() - - def setup(self): - distorted_A, fake_A, fake_sz64_A, mask_A, self.path_A, self.path_mask_A, self.path_abgr_A, self.path_bgr_A = self.cycle_variables(self.model.netGA) - distorted_B, fake_B, fake_sz64_B, mask_B, self.path_B, self.path_mask_B, self.path_abgr_B, self.path_bgr_B = self.cycle_variables(self.model.netGB) - real_A = Input(shape=self.model.img_shape) - real_B = Input(shape=self.model.img_shape) - - if self.use_lsgan: - self.loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target))) - else: - self.loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target)) - - # ========== Define Perceptual Loss Model========== - if self.use_perceptual_loss: - from keras.models import Model - from keras_vggface.vggface import VGGFace - vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3)) - vggface.trainable = False - out_size55 = vggface.layers[36].output - out_size28 = vggface.layers[78].output - out_size7 = vggface.layers[-2].output - vggface_feat = Model(vggface.input, [out_size55, out_size28, out_size7]) - vggface_feat.trainable = False - else: - vggface_feat = None - - loss_DA, loss_GA = self.define_loss(self.model.netDA, real_A, fake_A, fake_sz64_A, distorted_A, vggface_feat) - loss_DB, loss_GB = self.define_loss(self.model.netDB, real_B, fake_B, fake_sz64_B, distorted_B, vggface_feat) - - if self.use_mask_refinement: - loss_GA += 1e-3 * K.mean(K.square(mask_A)) - loss_GB += 1e-3 * K.mean(K.square(mask_B)) - else: - loss_GA += 3e-3 * K.mean(K.abs(mask_A)) - loss_GB += 3e-3 * K.mean(K.abs(mask_B)) - - w_fo = 0.01 - loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=1)) - loss_GA += w_fo * K.mean(self.first_order(mask_A, axis=2)) - loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=1)) - loss_GB += w_fo * K.mean(self.first_order(mask_B, axis=2)) - - weightsDA = self.model.netDA.trainable_weights - weightsGA = self.model.netGA.trainable_weights - weightsDB = self.model.netDB.trainable_weights - weightsGB = self.model.netGB.trainable_weights - - # Adam(..).get_updates(...) - training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA) - self.netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates) - training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGA,[], loss_GA) - self.netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates) - - training_updates = Adam(lr=self.lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB) - self.netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates) - training_updates = Adam(lr=self.lrG, beta_1=0.5).get_updates(weightsGB,[], loss_GB) - self.netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates) - - def first_order(self, x, axis=1): - img_nrows = x.shape[1] - img_ncols = x.shape[2] - if axis == 1: - return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :]) - elif axis == 2: - return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :]) - else: - return None - - def train_one_step(self, iter, viewer): - # --------------------- - # Train Discriminators - # --------------------- - - # Select a random half batch of images - epoch, warped_A, target_A = next(self.train_batchA) - epoch, warped_B, target_B = next(self.train_batchB) - - # Train dicriminators for one batch - errDA = self.netDA_train([warped_A, target_A]) - errDB = self.netDB_train([warped_B, target_B]) - - # Train generators for one batch - errGA = self.netGA_train([warped_A, target_A]) - errGB = self.netGB_train([warped_B, target_B]) - - # For calculating average losses - self.errDA_sum += errDA[0] - self.errDB_sum += errDB[0] - self.errGA_sum += errGA[0] - self.errGB_sum += errGB[0] - self.avg_counter += 1 - - print('[%s] [%d/%s][%d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f' - % (time.strftime("%H:%M:%S"), epoch, "num_epochs", iter, self.errDA_sum/self.avg_counter, self.errDB_sum/self.avg_counter, self.errGA_sum/self.avg_counter, self.errGB_sum/self.avg_counter), - end='\r') - - if viewer is not None: - self.show_sample(viewer) - - def cycle_variables(self, netG): - distorted_input = netG.inputs[0] - fake_output = netG.outputs[0] - fake_output64 = netG.outputs[1] - alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output) - rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output) - - masked_fake_output = alpha * rgb + (1-alpha) * distorted_input - - fn_generate = K.function([distorted_input], [masked_fake_output]) - fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])]) - fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])]) - fn_bgr = K.function([distorted_input], [rgb]) - return distorted_input, fake_output, fake_output64, alpha, fn_generate, fn_mask, fn_abgr, fn_bgr - - def define_loss(self, netD, real, fake_argb, fake_sz64, distorted, vggface_feat=None): - alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb) - fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb) - fake = alpha * fake_rgb + (1-alpha) * distorted - - if self.use_mixup: - dist = Beta(self.mixup_alpha, self.mixup_alpha) - lam = dist.sample() - # ========== - mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted]) - # ========== - output_mixup = netD(mixup) - loss_D = self.loss_fn(output_mixup, lam * K.ones_like(output_mixup)) - #output_fake = netD(concatenate([fake, distorted])) # dummy - loss_G = 1 * self.loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup)) - else: - output_real = netD(concatenate([real, distorted])) # positive sample - output_fake = netD(concatenate([fake, distorted])) # negative sample - loss_D_real = self.loss_fn(output_real, K.ones_like(output_real)) - loss_D_fake = self.loss_fn(output_fake, K.zeros_like(output_fake)) - loss_D = loss_D_real + loss_D_fake - loss_G = 1 * self.loss_fn(output_fake, K.ones_like(output_fake)) - # ========== - if self.use_mask_refinement: - loss_G += K.mean(K.abs(fake - real)) - else: - loss_G += K.mean(K.abs(fake_rgb - real)) - loss_G += K.mean(K.abs(fake_sz64 - tf.image.resize_images(real, [64, 64]))) - # ========== - - # Perceptual Loss - if not vggface_feat is None: - def preprocess_vggface(x): - x = (x + 1)/2 * 255 # channel order: BGR - x -= [93.5940, 104.7624, 129.] - return x - pl_params = (0.02, 0.3, 0.5) - real_sz224 = tf.image.resize_images(real, [224, 224]) - real_sz224 = Lambda(preprocess_vggface)(real_sz224) - # ========== - if self.use_mask_refinement: - fake_sz224 = tf.image.resize_images(fake, [224, 224]) - else: - fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224]) - fake_sz224 = Lambda(preprocess_vggface)(fake_sz224) - # ========== - real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224) - fake_feat55, fake_feat28, fake_feat7 = vggface_feat(fake_sz224) - loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7)) - loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28)) - loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55)) - - return loss_D, loss_G - - def show_sample(self, display_fn): - _, wA, tA = next(self.train_batchA) - _, wB, tB = next(self.train_batchB) - display_fn(self.showG(tA, tB, self.path_A, self.path_B), "masked") - display_fn(self.showG(tA, tB, self.path_bgr_A, self.path_bgr_B), "raw") - display_fn(self.showG_mask(tA, tB, self.path_mask_A, self.path_mask_B), "mask") - # Reset the averages - self.errDA_sum = self.errDB_sum = self.errGA_sum = self.errGB_sum = 0 - self.avg_counter = 0 - - def showG(self, test_A, test_B, path_A, path_B): - figure_A = np.stack([ - test_A, - np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])), - np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])), - ], axis=1 ) - figure_B = np.stack([ - test_B, - np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])), - np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])), - ], axis=1 ) - - figure = np.concatenate([figure_A, figure_B], axis=0 ) - figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:]) - figure = stack_images(figure) - figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8') - return figure - - def showG_mask(self, test_A, test_B, path_A, path_B): - figure_A = np.stack([ - test_A, - (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1, - (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1, - ], axis=1 ) - figure_B = np.stack([ - test_B, - (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1, - (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1, - ], axis=1 ) - - figure = np.concatenate([figure_A, figure_B], axis=0 ) - figure = figure.reshape((4,self.batch_size // 2) + figure.shape[1:]) - figure = stack_images(figure) - figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8') - return figure diff --git a/plugins/model/Model_GAN128/__init__.py b/plugins/model/Model_GAN128/__init__.py deleted file mode 100644 index 0722bbad07..0000000000 --- a/plugins/model/Model_GAN128/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """Based on https://github.com/shaoanlu/""" -__version__ = '0.1.0' - -from .Model import GANModel as Model -from .Trainer import Trainer \ No newline at end of file diff --git a/plugins/model/Model_GAN128/instance_normalization.py b/plugins/model/Model_GAN128/instance_normalization.py deleted file mode 100644 index 9c0616af27..0000000000 --- a/plugins/model/Model_GAN128/instance_normalization.py +++ /dev/null @@ -1,145 +0,0 @@ -from keras.engine import Layer, InputSpec -from keras import initializers, regularizers, constraints -from keras import backend as K -from keras.utils.generic_utils import get_custom_objects - -import numpy as np - - -class InstanceNormalization(Layer): - """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). - Normalize the activations of the previous layer at each step, - i.e. applies a transformation that maintains the mean activation - close to 0 and the activation standard deviation close to 1. - # Arguments - axis: Integer, the axis that should be normalized - (typically the features axis). - For instance, after a `Conv2D` layer with - `data_format="channels_first"`, - set `axis=1` in `InstanceNormalization`. - Setting `axis=None` will normalize all values in each instance of the batch. - Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. - scale: If True, multiply by `gamma`. - If False, `gamma` is not used. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: Optional constraint for the beta weight. - gamma_constraint: Optional constraint for the gamma weight. - # Input shape - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - # Output shape - Same shape as input. - # References - - [Layer Normalization](https://arxiv.org/abs/1607.06450) - - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022) - """ - def __init__(self, - axis=None, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - **kwargs): - super(InstanceNormalization, self).__init__(**kwargs) - self.supports_masking = True - self.axis = axis - self.epsilon = epsilon - self.center = center - self.scale = scale - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - self.beta_regularizer = regularizers.get(beta_regularizer) - self.gamma_regularizer = regularizers.get(gamma_regularizer) - self.beta_constraint = constraints.get(beta_constraint) - self.gamma_constraint = constraints.get(gamma_constraint) - - def build(self, input_shape): - ndim = len(input_shape) - if self.axis == 0: - raise ValueError('Axis cannot be zero') - - if (self.axis is not None) and (ndim == 2): - raise ValueError('Cannot specify axis for rank 1 tensor') - - self.input_spec = InputSpec(ndim=ndim) - - if self.axis is None: - shape = (1,) - else: - shape = (input_shape[self.axis],) - - if self.scale: - self.gamma = self.add_weight(shape=shape, - name='gamma', - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - constraint=self.gamma_constraint) - else: - self.gamma = None - if self.center: - self.beta = self.add_weight(shape=shape, - name='beta', - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - constraint=self.beta_constraint) - else: - self.beta = None - self.built = True - - def call(self, inputs, training=None): - input_shape = K.int_shape(inputs) - reduction_axes = list(range(0, len(input_shape))) - - if (self.axis is not None): - del reduction_axes[self.axis] - - del reduction_axes[0] - - mean = K.mean(inputs, reduction_axes, keepdims=True) - stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon - normed = (inputs - mean) / stddev - - broadcast_shape = [1] * len(input_shape) - if self.axis is not None: - broadcast_shape[self.axis] = input_shape[self.axis] - - if self.scale: - broadcast_gamma = K.reshape(self.gamma, broadcast_shape) - normed = normed * broadcast_gamma - if self.center: - broadcast_beta = K.reshape(self.beta, broadcast_shape) - normed = normed + broadcast_beta - return normed - - def get_config(self): - config = { - 'axis': self.axis, - 'epsilon': self.epsilon, - 'center': self.center, - 'scale': self.scale, - 'beta_initializer': initializers.serialize(self.beta_initializer), - 'gamma_initializer': initializers.serialize(self.gamma_initializer), - 'beta_regularizer': regularizers.serialize(self.beta_regularizer), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_constraint': constraints.serialize(self.beta_constraint), - 'gamma_constraint': constraints.serialize(self.gamma_constraint) - } - base_config = super(InstanceNormalization, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - -get_custom_objects().update({'InstanceNormalization': InstanceNormalization}) diff --git a/plugins/model/Model_IAE/AutoEncoder.py b/plugins/model/Model_IAE/AutoEncoder.py deleted file mode 100644 index 3f012dbd2e..0000000000 --- a/plugins/model/Model_IAE/AutoEncoder.py +++ /dev/null @@ -1,53 +0,0 @@ -# Improved-AutoEncoder base classes - -import logging - -from lib.utils import backup_file - -hdf = {'encoderH5': 'IAE_encoder.h5', - 'decoderH5': 'IAE_decoder.h5', - 'inter_AH5': 'IAE_inter_A.h5', - 'inter_BH5': 'IAE_inter_B.h5', - 'inter_bothH5': 'IAE_inter_both.h5'} - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class AutoEncoder: - def __init__(self, model_dir, gpus): - self.model_dir = model_dir - self.gpus = gpus - - self.encoder = self.Encoder() - self.decoder = self.Decoder() - self.inter_A = self.Intermidiate() - self.inter_B = self.Intermidiate() - self.inter_both = self.Intermidiate() - - self.initModel() - - def load(self, swapped): - (face_A,face_B) = (hdf['inter_AH5'], hdf['inter_BH5']) if not swapped else (hdf['inter_BH5'], hdf['inter_AH5']) - - try: - self.encoder.load_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder.load_weights(str(self.model_dir / hdf['decoderH5'])) - self.inter_both.load_weights(str(self.model_dir / hdf['inter_bothH5'])) - self.inter_A.load_weights(str(self.model_dir / face_A)) - self.inter_B.load_weights(str(self.model_dir / face_B)) - logger.info('loaded model weights') - return True - except Exception: - logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir) - return False - - def save_weights(self): - model_dir = str(self.model_dir) - for model in hdf.values(): - backup_file(model_dir, model) - self.encoder.save_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder.save_weights(str(self.model_dir / hdf['decoderH5'])) - self.inter_both.save_weights(str(self.model_dir / hdf['inter_bothH5'])) - self.inter_A.save_weights(str(self.model_dir / hdf['inter_AH5'])) - self.inter_B.save_weights(str(self.model_dir / hdf['inter_BH5'])) - logger.info('saved model weights') diff --git a/plugins/model/Model_IAE/Model.py b/plugins/model/Model_IAE/Model.py deleted file mode 100644 index 65f51de582..0000000000 --- a/plugins/model/Model_IAE/Model.py +++ /dev/null @@ -1,77 +0,0 @@ -# Improved autoencoder for faceswap. - -from keras.models import Model as KerasModel -from keras.layers import Input, Dense, Flatten, Reshape, Concatenate -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import Conv2D -from keras.optimizers import Adam - -from .AutoEncoder import AutoEncoder -from lib.PixelShuffler import PixelShuffler - -from keras.utils import multi_gpu_model - -IMAGE_SHAPE = (64, 64, 3) -ENCODER_DIM = 1024 - -class Model(AutoEncoder): - def initModel(self): - optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - x = Input(shape=IMAGE_SHAPE) - - self.autoencoder_A = KerasModel(x, self.decoder(Concatenate()([self.inter_A(self.encoder(x)), self.inter_both(self.encoder(x))]))) - self.autoencoder_B = KerasModel(x, self.decoder(Concatenate()([self.inter_B(self.encoder(x)), self.inter_both(self.encoder(x))]))) - - if self.gpus > 1: - self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus) - self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus) - - self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error') - self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error') - - def converter(self, swap): - autoencoder = self.autoencoder_B if not swap else self.autoencoder_A - return lambda img: autoencoder.predict(img) - - def conv(self, filters): - def block(x): - x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x) - x = LeakyReLU(0.1)(x) - return x - return block - - def upscale(self, filters): - def block(x): - x = Conv2D(filters * 4, kernel_size=3, padding='same')(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder(self): - input_ = Input(shape=IMAGE_SHAPE) - x = input_ - x = self.conv(128)(x) - x = self.conv(256)(x) - x = self.conv(512)(x) - x = self.conv(1024)(x) - x = Flatten()(x) - return KerasModel(input_, x) - - def Intermidiate(self): - input_ = Input(shape=(None, 4 * 4 * 1024)) - x = input_ - x = Dense(ENCODER_DIM)(x) - x = Dense(4 * 4 * int(ENCODER_DIM/2))(x) - x = Reshape((4, 4, int(ENCODER_DIM/2)))(x) - return KerasModel(input_, x) - - def Decoder(self): - input_ = Input(shape=(4, 4, ENCODER_DIM)) - x = input_ - x = self.upscale(512)(x) - x = self.upscale(256)(x) - x = self.upscale(128)(x) - x = self.upscale(64)(x) - x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) - return KerasModel(input_, x) diff --git a/plugins/model/Model_IAE/Trainer.py b/plugins/model/Model_IAE/Trainer.py deleted file mode 100644 index 67f954db8f..0000000000 --- a/plugins/model/Model_IAE/Trainer.py +++ /dev/null @@ -1,51 +0,0 @@ - -import time -import numpy -from lib.training_data import TrainingDataGenerator, stack_images - - -class Trainer(): - random_transform_args = { - 'rotation_range': 10, - 'zoom_range': 0.05, - 'shift_range': 0.05, - 'random_flip': 0.4, - } - - def __init__(self, model, fn_A, fn_B, batch_size, *args): - self.batch_size = batch_size - self.model = model - - generator = TrainingDataGenerator(self.random_transform_args, 160) - self.images_A = generator.minibatchAB(fn_A, self.batch_size) - self.images_B = generator.minibatchAB(fn_B, self.batch_size) - - def train_one_step(self, iter, viewer): - epoch, warped_A, target_A = next(self.images_A) - epoch, warped_B, target_B = next(self.images_B) - - loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A) - loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B) - print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B), - end='\r') - - if viewer is not None: - viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training") - - def show_sample(self, test_A, test_B): - figure_A = numpy.stack([ - test_A, - self.model.autoencoder_A.predict(test_A), - self.model.autoencoder_B.predict(test_A), - ], axis=1) - figure_B = numpy.stack([ - test_B, - self.model.autoencoder_B.predict(test_B), - self.model.autoencoder_A.predict(test_B), - ], axis=1) - - figure = numpy.concatenate([figure_A, figure_B], axis=0) - figure = figure.reshape((4, 7) + figure.shape[1:]) - figure = stack_images(figure) - - return numpy.clip(figure * 255, 0, 255).astype('uint8') diff --git a/plugins/model/Model_IAE/__init__.py b/plugins/model/Model_IAE/__init__.py deleted file mode 100644 index f1151ee7a9..0000000000 --- a/plugins/model/Model_IAE/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """acsaga""" -__version__ = '0.1.0' - -from .Model import Model -from .Trainer import Trainer -from .AutoEncoder import AutoEncoder diff --git a/plugins/model/Model_LowMem/AutoEncoder.py b/plugins/model/Model_LowMem/AutoEncoder.py deleted file mode 100644 index 997d5643c9..0000000000 --- a/plugins/model/Model_LowMem/AutoEncoder.py +++ /dev/null @@ -1,61 +0,0 @@ -# AutoEncoder base classes -import logging - -from lib.utils import backup_file - -hdf = {'encoderH5': 'lowmem_encoder.h5', - 'decoder_AH5': 'lowmem_decoder_A.h5', - 'decoder_BH5': 'lowmem_decoder_B.h5'} - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - -#Part of Filename migration, should be remopved some reasonable time after first added -import os.path -old_encoderH5 = 'encoder.h5' -old_decoder_AH5 = 'decoder_A.h5' -old_decoder_BH5 = 'decoder_B.h5' -#End filename migration - -class AutoEncoder: - def __init__(self, model_dir, gpus): - self.model_dir = model_dir - self.gpus = gpus - - self.encoder = self.Encoder() - self.decoder_A = self.Decoder() - self.decoder_B = self.Decoder() - - self.initModel() - - def load(self, swapped): - (face_A,face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5']) - - try: - #Part of Filename migration, should be remopved some reasonable time after first added - if os.path.isfile(str(self.model_dir / old_encoderH5)): - logger.info('Migrating to new filenames:') - if os.path.isfile(str(self.model_dir / hdf['encoderH5'])) is not True: - os.rename(str(self.model_dir / old_decoder_AH5), str(self.model_dir / hdf['decoder_AH5'])) - os.rename(str(self.model_dir / old_decoder_BH5), str(self.model_dir / hdf['decoder_BH5'])) - os.rename(str(self.model_dir / old_encoderH5), str(self.model_dir / hdf['encoderH5'])) - logger.info('Complete') - else: - logger.warning('Failed due to existing files in folder. Loading already migrated files') - #End filename migration - self.encoder.load_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder_A.load_weights(str(self.model_dir / face_A)) - self.decoder_B.load_weights(str(self.model_dir / face_B)) - logger.info('loaded model weights') - return True - except Exception as e: - logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir) - return False - - def save_weights(self): - model_dir = str(self.model_dir) - for model in hdf.values(): - backup_file(model_dir, model) - self.encoder.save_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5'])) - self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5'])) - logger.info('saved model weights') diff --git a/plugins/model/Model_LowMem/Model.py b/plugins/model/Model_LowMem/Model.py deleted file mode 100644 index 506e6e9167..0000000000 --- a/plugins/model/Model_LowMem/Model.py +++ /dev/null @@ -1,70 +0,0 @@ -# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs - -from keras.models import Model as KerasModel -from keras.layers import Input, Dense, Flatten, Reshape -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import Conv2D -from keras.optimizers import Adam - -from .AutoEncoder import AutoEncoder -from lib.PixelShuffler import PixelShuffler - -from keras.utils import multi_gpu_model - -IMAGE_SHAPE = (64, 64, 3) -ENCODER_DIM = 512 - -class Model(AutoEncoder): - def initModel(self): - optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - x = Input(shape=IMAGE_SHAPE) - - self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x))) - self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x))) - - if self.gpus > 1: - self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus) - self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus) - - self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error') - self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error') - - def converter(self, swap): - autoencoder = self.autoencoder_B if not swap else self.autoencoder_A - return lambda img: autoencoder.predict(img) - - def conv(self, filters): - def block(x): - x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x) - x = LeakyReLU(0.1)(x) - return x - return block - - def upscale(self, filters): - def block(x): - x = Conv2D(filters * 4, kernel_size=3, padding='same')(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder(self): - input_ = Input(shape=IMAGE_SHAPE) - x = input_ - x = self.conv(128)(x) - x = self.conv(256)(x) - x = self.conv(512)(x) - x = Dense(ENCODER_DIM)(Flatten()(x)) - x = Dense(4 * 4 * 1024)(x) - x = Reshape((4, 4, 1024))(x) - x = self.upscale(512)(x) - return KerasModel(input_, x) - - def Decoder(self): - input_ = Input(shape=(8, 8, 512)) - x = input_ - x = self.upscale(256)(x) - x = self.upscale(128)(x) - x = self.upscale(64)(x) - x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) - return KerasModel(input_, x) diff --git a/plugins/model/Model_LowMem/Trainer.py b/plugins/model/Model_LowMem/Trainer.py deleted file mode 100644 index 4f55b399e1..0000000000 --- a/plugins/model/Model_LowMem/Trainer.py +++ /dev/null @@ -1,56 +0,0 @@ - -import time -import numpy -from lib.training_data import TrainingDataGenerator, stack_images - -class Trainer(): - random_transform_args = { - 'rotation_range': 10, - 'zoom_range': 0.05, - 'shift_range': 0.05, - 'random_flip': 0.4, - } - - def __init__(self, model, fn_A, fn_B, batch_size, *args): - self.batch_size = batch_size - self.model = model - - generator = TrainingDataGenerator(self.random_transform_args, 160) - self.images_A = generator.minibatchAB(fn_A, self.batch_size) - self.images_B = generator.minibatchAB(fn_B, self.batch_size) - - def train_one_step(self, iter, viewer): - epoch, warped_A, target_A = next(self.images_A) - epoch, warped_B, target_B = next(self.images_B) - - loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A) - loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B) - print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), iter, loss_A, loss_B), - end='\r') - - if viewer is not None: - viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training") - - def show_sample(self, test_A, test_B): - figure_A = numpy.stack([ - test_A, - self.model.autoencoder_A.predict(test_A), - self.model.autoencoder_B.predict(test_A), - ], axis=1) - figure_B = numpy.stack([ - test_B, - self.model.autoencoder_B.predict(test_B), - self.model.autoencoder_A.predict(test_B), - ], axis=1) - - if test_A.shape[0] % 2 == 1: - figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ]) - figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ]) - - figure = numpy.concatenate([figure_A, figure_B], axis=0) - w = 4 - h = int( figure.shape[0] / w) - figure = figure.reshape((w, h) + figure.shape[1:]) - figure = stack_images(figure) - - return numpy.clip(figure * 255, 0, 255).astype('uint8') diff --git a/plugins/model/Model_LowMem/__init__.py b/plugins/model/Model_LowMem/__init__.py deleted file mode 100644 index 82d540fe2e..0000000000 --- a/plugins/model/Model_LowMem/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """Based on https://reddit.com/u/deepfakes/""" -__version__ = '0.1.0' - -from .Model import Model -from .Trainer import Trainer -from .AutoEncoder import AutoEncoder \ No newline at end of file diff --git a/plugins/model/Model_Original/AutoEncoder.py b/plugins/model/Model_Original/AutoEncoder.py deleted file mode 100644 index a0c4f795e9..0000000000 --- a/plugins/model/Model_Original/AutoEncoder.py +++ /dev/null @@ -1,77 +0,0 @@ -# AutoEncoder base classes -import logging - -from lib.utils import backup_file -from lib import Serializer -from json import JSONDecodeError - -hdf = {'encoderH5': 'encoder.h5', - 'decoder_AH5': 'decoder_A.h5', - 'decoder_BH5': 'decoder_B.h5', - 'state': 'state'} -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - - -class AutoEncoder: - def __init__(self, model_dir, gpus): - self.model_dir = model_dir - self.gpus = gpus - - self.encoder = self.Encoder() - self.decoder_A = self.Decoder() - self.decoder_B = self.Decoder() - - self.initModel() - - def load(self, swapped): - serializer = Serializer.get_serializer('json') - state_fn = ".".join([hdf['state'], serializer.ext]) - try: - with open(str(self.model_dir / state_fn), 'rb') as fp: - state = serializer.unmarshal(fp.read().decode('utf-8')) - self._epoch_no = state['epoch_no'] - except IOError as e: - logger.warning('Error loading training info: %s', str(e.strerror)) - self._epoch_no = 0 - except JSONDecodeError as e: - logger.warning('Error loading training info: %s', str(e.msg)) - self._epoch_no = 0 - - (face_A,face_B) = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5']) - - try: - self.encoder.load_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder_A.load_weights(str(self.model_dir / face_A)) - self.decoder_B.load_weights(str(self.model_dir / face_B)) - logger.info('loaded model weights') - return True - except Exception as e: - logger.warning('Failed loading existing training data. Starting a fresh model: %s', self.model_dir) - return False - - def save_weights(self): - model_dir = str(self.model_dir) - for model in hdf.values(): - backup_file(model_dir, model) - self.encoder.save_weights(str(self.model_dir / hdf['encoderH5'])) - self.decoder_A.save_weights(str(self.model_dir / hdf['decoder_AH5'])) - self.decoder_B.save_weights(str(self.model_dir / hdf['decoder_BH5'])) - - logger.info('saved model weights') - - serializer = Serializer.get_serializer('json') - state_fn = ".".join([hdf['state'], serializer.ext]) - state_dir = str(self.model_dir / state_fn) - try: - with open(state_dir, 'wb') as fp: - state_json = serializer.marshal({ - 'epoch_no' : self.epoch_no - }) - fp.write(state_json.encode('utf-8')) - except IOError as e: - logger.error(e.strerror) - - @property - def epoch_no(self): - "Get current training epoch number" - return self._epoch_no diff --git a/plugins/model/Model_Original/Model.py b/plugins/model/Model_Original/Model.py deleted file mode 100644 index ce7648ebc9..0000000000 --- a/plugins/model/Model_Original/Model.py +++ /dev/null @@ -1,71 +0,0 @@ -# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs - -from keras.models import Model as KerasModel -from keras.layers import Input, Dense, Flatten, Reshape -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import Conv2D -from keras.optimizers import Adam - -from .AutoEncoder import AutoEncoder -from lib.PixelShuffler import PixelShuffler - -from keras.utils import multi_gpu_model - -IMAGE_SHAPE = (64, 64, 3) -ENCODER_DIM = 1024 - -class Model(AutoEncoder): - def initModel(self): - optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - x = Input(shape=IMAGE_SHAPE) - - self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x))) - self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x))) - - if self.gpus > 1: - self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus) - self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus) - - self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error') - self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error') - - def converter(self, swap): - autoencoder = self.autoencoder_B if not swap else self.autoencoder_A - return lambda img: autoencoder.predict(img) - - def conv(self, filters): - def block(x): - x = Conv2D(filters, kernel_size=5, strides=2, padding='same')(x) - x = LeakyReLU(0.1)(x) - return x - return block - - def upscale(self, filters): - def block(x): - x = Conv2D(filters * 4, kernel_size=3, padding='same')(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder(self): - input_ = Input(shape=IMAGE_SHAPE) - x = input_ - x = self.conv(128)(x) - x = self.conv(256)(x) - x = self.conv(512)(x) - x = self.conv(1024)(x) - x = Dense(ENCODER_DIM)(Flatten()(x)) - x = Dense(4 * 4 * 1024)(x) - x = Reshape((4, 4, 1024))(x) - x = self.upscale(512)(x) - return KerasModel(input_, x) - - def Decoder(self): - input_ = Input(shape=(8, 8, 512)) - x = input_ - x = self.upscale(256)(x) - x = self.upscale(128)(x) - x = self.upscale(64)(x) - x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) - return KerasModel(input_, x) diff --git a/plugins/model/Model_Original/Trainer.py b/plugins/model/Model_Original/Trainer.py deleted file mode 100644 index 17216b9ea5..0000000000 --- a/plugins/model/Model_Original/Trainer.py +++ /dev/null @@ -1,59 +0,0 @@ - -import time -import numpy -from lib.training_data import TrainingDataGenerator, stack_images - -class Trainer(): - random_transform_args = { - 'rotation_range': 10, - 'zoom_range': 0.05, - 'shift_range': 0.05, - 'random_flip': 0.4, - } - - def __init__(self, model, fn_A, fn_B, batch_size, *args): - self.batch_size = batch_size - self.model = model - - generator = TrainingDataGenerator(self.random_transform_args, 160) - self.images_A = generator.minibatchAB(fn_A, self.batch_size) - self.images_B = generator.minibatchAB(fn_B, self.batch_size) - - def train_one_step(self, iter, viewer): - epoch, warped_A, target_A = next(self.images_A) - epoch, warped_B, target_B = next(self.images_B) - - loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A) - loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B) - - self.model._epoch_no += 1 - - print("[{0}] [#{1:05d}] loss_A: {2:.5f}, loss_B: {3:.5f}".format(time.strftime("%H:%M:%S"), self.model.epoch_no, loss_A, loss_B), - end='\r') - - if viewer is not None: - viewer(self.show_sample(target_A[0:14], target_B[0:14]), "training") - - def show_sample(self, test_A, test_B): - figure_A = numpy.stack([ - test_A, - self.model.autoencoder_A.predict(test_A), - self.model.autoencoder_B.predict(test_A), - ], axis=1) - figure_B = numpy.stack([ - test_B, - self.model.autoencoder_B.predict(test_B), - self.model.autoencoder_A.predict(test_B), - ], axis=1) - - if test_A.shape[0] % 2 == 1: - figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ]) - figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ]) - - figure = numpy.concatenate([figure_A, figure_B], axis=0) - w = 4 - h = int( figure.shape[0] / w) - figure = figure.reshape((w, h) + figure.shape[1:]) - figure = stack_images(figure) - - return numpy.clip(figure * 255, 0, 255).astype('uint8') diff --git a/plugins/model/Model_Original/__init__.py b/plugins/model/Model_Original/__init__.py deleted file mode 100644 index 82d540fe2e..0000000000 --- a/plugins/model/Model_Original/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """Based on https://reddit.com/u/deepfakes/""" -__version__ = '0.1.0' - -from .Model import Model -from .Trainer import Trainer -from .AutoEncoder import AutoEncoder \ No newline at end of file diff --git a/plugins/model/Model_OriginalHighRes/Model.py b/plugins/model/Model_OriginalHighRes/Model.py deleted file mode 100644 index 804d948941..0000000000 --- a/plugins/model/Model_OriginalHighRes/Model.py +++ /dev/null @@ -1,312 +0,0 @@ -#!/usr/bin/python3 - -# Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs -# Based on https://github.com/iperov/OpenDeepFaceSwap for Decoder multiple res block chain -# Based on the https://github.com/shaoanlu/faceswap-GAN repo -# source : https://github.com/shaoanlu/faceswap-GAN/blob/master/FaceSwap_GAN_v2_sz128_train.ipynbtemp/faceswap_GAN_keras.ipynb - - -import enum -import logging -import os -import sys -import warnings -warnings.filterwarnings("ignore", category=FutureWarning) - -from keras.initializers import RandomNormal -from keras.layers import Input, Dense, Flatten, Reshape -from keras.layers import SeparableConv2D, add -from keras.layers.advanced_activations import LeakyReLU -from keras.layers.convolutional import Conv2D -from keras.layers.core import Activation -from keras.models import Model as KerasModel -from keras.optimizers import Adam -from keras.utils import multi_gpu_model - -from lib.PixelShuffler import PixelShuffler -import lib.Serializer -from lib.utils import backup_file - -from . import __version__ -from .instance_normalization import InstanceNormalization - - -if isinstance(__version__, (list, tuple)): - version_str = ".".join([str(n) for n in __version__[1:]]) -else: - version_str = __version__ - -logger = logging.getLogger(__name__) # pylint: disable=invalid-name -mswindows = sys.platform=="win32" - - -class EncoderType(enum.Enum): - ORIGINAL = "original" - SHAOANLU = "shaoanlu" - - -_kern_init = RandomNormal(0, 0.02) - - -def inst_norm(): - return InstanceNormalization() - - -ENCODER = EncoderType.ORIGINAL - - -hdf = {'encoderH5': 'encoder_{version_str}{ENCODER.value}.h5'.format(**vars()), - 'decoder_AH5': 'decoder_A_{version_str}{ENCODER.value}.h5'.format(**vars()), - 'decoder_BH5': 'decoder_B_{version_str}{ENCODER.value}.h5'.format(**vars())} - -class Model(): - - ENCODER_DIM = 1024 # dense layer size - IMAGE_SHAPE = 128, 128 # image shape - - assert [n for n in IMAGE_SHAPE if n>=16] - - IMAGE_WIDTH = max(IMAGE_SHAPE) - IMAGE_WIDTH = (IMAGE_WIDTH//16 + (1 if (IMAGE_WIDTH%16)>=8 else 0))*16 - IMAGE_SHAPE = IMAGE_WIDTH, IMAGE_WIDTH, len('BRG') # good to let ppl know what these are... - - - def __init__(self, model_dir, gpus, encoder_type=ENCODER): - - if mswindows: - from ctypes import cdll - mydll = cdll.LoadLibrary("user32.dll") - mydll.SetProcessDPIAware(True) - - self._encoder_type = encoder_type - - self.model_dir = model_dir - - # can't chnage gpu's when the model is initialized no point in making it r/w - self._gpus = gpus - - Encoder = getattr(self, "Encoder_{}".format(self._encoder_type.value)) - Decoder = getattr(self, "Decoder_{}".format(self._encoder_type.value)) - - self.encoder = Encoder() - self.decoder_A = Decoder() - self.decoder_B = Decoder() - - self.initModel() - - - def initModel(self): - optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) - - x = Input(shape=self.IMAGE_SHAPE) - - self.autoencoder_A = KerasModel(x, self.decoder_A(self.encoder(x))) - self.autoencoder_B = KerasModel(x, self.decoder_B(self.encoder(x))) - - if self.gpus > 1: - self.autoencoder_A = multi_gpu_model( self.autoencoder_A , self.gpus) - self.autoencoder_B = multi_gpu_model( self.autoencoder_B , self.gpus) - - - self.autoencoder_A.compile(optimizer=optimizer, loss='mean_absolute_error') - self.autoencoder_B.compile(optimizer=optimizer, loss='mean_absolute_error') - - - def load(self, swapped): - model_dir = str(self.model_dir) - - from json import JSONDecodeError - face_A, face_B = (hdf['decoder_AH5'], hdf['decoder_BH5']) if not swapped else (hdf['decoder_BH5'], hdf['decoder_AH5']) - - state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals())) - ser = lib.Serializer.get_serializer('json') - - try: - with open(state_dir, 'rb') as fp: - state = ser.unmarshal(fp.read().decode('utf-8')) - self._epoch_no = state['epoch_no'] - except IOError as e: - logger.warning('Error loading training info: %s', str(e.strerror)) - self._epoch_no = 0 - except JSONDecodeError as e: - logger.warning('Error loading training info: %s', str(e.msg)) - self._epoch_no = 0 - - try: - self.encoder.load_weights(os.path.join(model_dir, hdf['encoderH5'])) - self.decoder_A.load_weights(os.path.join(model_dir, face_A)) - self.decoder_B.load_weights(os.path.join(model_dir, face_B)) - logger.info('loaded model weights') - return True - except IOError as e: - logger.warning('Error loading training info: %s', str(e.strerror)) - except Exception as e: - logger.warning('Error loading training info: %s', str(e)) - - return False - - def converter(self, swap): - autoencoder = self.autoencoder_B if not swap else self.autoencoder_A - return autoencoder.predict - - - def conv(self, filters, kernel_size=5, strides=2, **kwargs): - def block(x): - x = Conv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x) - x = LeakyReLU(0.1)(x) - return x - return block - - def conv_sep(self, filters, kernel_size=5, strides=2, use_instance_norm=True, **kwargs): - def block(x): - x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x) - x = Activation("relu")(x) - return x - return block - - def conv_inst_norm(self, filters, kernel_size=3, strides=2, use_instance_norm=True, **kwargs): - def block(x): - x = SeparableConv2D(filters, kernel_size=kernel_size, strides=strides, kernel_initializer=_kern_init, padding='same', **kwargs)(x) - if use_instance_norm: - x = inst_norm()(x) - x = Activation("relu")(x) - return x - return block - - def upscale(self, filters, **kwargs): - def block(x): - x = Conv2D(filters * 4, kernel_size=3, padding='same', - kernel_initializer=_kern_init)(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def upscale_inst_norm(self, filters, use_instance_norm=True, **kwargs): - def block(x): - x = Conv2D(filters*4, kernel_size=3, use_bias=False, - kernel_initializer=_kern_init, padding='same', **kwargs)(x) - if use_instance_norm: - x = inst_norm()(x) - x = LeakyReLU(0.1)(x) - x = PixelShuffler()(x) - return x - return block - - def Encoder_original(self, **kwargs): - impt = Input(shape=self.IMAGE_SHAPE) - - in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0]-128)//4 - - x = self.conv(in_conv_filters)(impt) - x = self.conv_sep(256)(x) - x = self.conv(512)(x) - x = self.conv_sep(1024)(x) - - dense_shape = self.IMAGE_SHAPE[0] // 16 - x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x)) - x = Dense(dense_shape * dense_shape * 512, kernel_initializer=_kern_init)(x) - x = Reshape((dense_shape, dense_shape, 512))(x) - x = self.upscale(512)(x) - - return KerasModel(impt, x, **kwargs) - - - def Encoder_shaoanlu(self, **kwargs): - impt = Input(shape=self.IMAGE_SHAPE) - - in_conv_filters = self.IMAGE_SHAPE[0] if self.IMAGE_SHAPE[0] <= 128 else 128 + (self.IMAGE_SHAPE[0]-128)//4 - - x = Conv2D(in_conv_filters, kernel_size=5, use_bias=False, padding="same")(impt) - x = self.conv_inst_norm(in_conv_filters+32, use_instance_norm=False)(x) - x = self.conv_inst_norm(256)(x) - x = self.conv_inst_norm(512)(x) - x = self.conv_inst_norm(1024)(x) - - dense_shape = self.IMAGE_SHAPE[0] // 16 - x = Dense(self.ENCODER_DIM, kernel_initializer=_kern_init)(Flatten()(x)) - x = Dense(dense_shape * dense_shape * 768, kernel_initializer=_kern_init)(x) - x = Reshape((dense_shape, dense_shape, 768))(x) - x = self.upscale(512)(x) - - return KerasModel(impt, x, **kwargs) - - - def Decoder_original(self): - decoder_shape = self.IMAGE_SHAPE[0]//8 - inpt = Input(shape=(decoder_shape, decoder_shape, 512)) - - x = self.upscale(384)(inpt) - x = self.upscale(256-32)(x) - x = self.upscale(self.IMAGE_SHAPE[0])(x) - - x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) - - return KerasModel(inpt, x) - - - def Decoder_shaoanlu(self): - decoder_shape = self.IMAGE_SHAPE[0]//8 - inpt = Input(shape=(decoder_shape, decoder_shape, 512)) - - x = self.upscale_inst_norm(512)(inpt) - x = self.upscale_inst_norm(256)(x) - x = self.upscale_inst_norm(self.IMAGE_SHAPE[0])(x) - - x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) - - return KerasModel(inpt, x) - - - def save_weights(self): - model_dir = str(self.model_dir) - - try: - for model in hdf.values(): - backup_file(model_dir, model) - except NameError: - logger.error('backup functionality not available\n') - - state_dir = os.path.join(model_dir, 'state_{version_str}_{ENCODER.value}.json'.format(**globals())) - ser = lib.Serializer.get_serializer('json') - try: - with open(state_dir, 'wb') as fp: - state_json = ser.marshal({ - 'epoch_no' : self._epoch_no - }) - fp.write(state_json.encode('utf-8')) - except IOError as e: - logger.error(e.strerror) - - logger.info('saving model weights') - - from concurrent.futures import ThreadPoolExecutor, as_completed - - with ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(getattr(self, mdl_name.rstrip('H5')).save_weights, str(self.model_dir / mdl_H5_fn)) for mdl_name, mdl_H5_fn in hdf.items()] - for future in as_completed(futures): - future.result() - print('.', end='', flush=True) - - logger.info('done') - - - @property - def gpus(self): - return self._gpus - - @property - def model_name(self): - try: - return self._model_name - except AttributeError: - import inspect - self._model_name = os.path.dirname(inspect.getmodule(self).__file__).rsplit("_", 1)[1] - return self._model_name - - - def __str__(self): - return "<{}: ver={}, dense_dim={}, img_size={}>".format(self.model_name, - version_str, - self.ENCODER_DIM, - "x".join([str(n) for n in self.IMAGE_SHAPE[:2]])) diff --git a/plugins/model/Model_OriginalHighRes/Trainer.py b/plugins/model/Model_OriginalHighRes/Trainer.py deleted file mode 100644 index ea79d5b72b..0000000000 --- a/plugins/model/Model_OriginalHighRes/Trainer.py +++ /dev/null @@ -1,84 +0,0 @@ -import time -import numpy - -from lib.training_data import TrainingDataGenerator, stack_images - - -TRANSFORM_PRC = 115. - - -class Trainer(): - - _random_transform_args = { - 'rotation_range': 10 * (TRANSFORM_PRC * .01), - 'zoom_range': 0.05 * (TRANSFORM_PRC * .01), - 'shift_range': 0.05 * (TRANSFORM_PRC * .01), - 'random_flip': 0.4 * (TRANSFORM_PRC * .01), - } - - def __init__(self, model, fn_A, fn_B, batch_size, *args): - self.batch_size = batch_size - self.model = model - from timeit import default_timer as clock - self._clock = clock - - generator = TrainingDataGenerator(self.random_transform_args, 160, 5, zoom=self.model.IMAGE_SHAPE[0]//64) - - self.images_A = generator.minibatchAB(fn_A, self.batch_size) - self.images_B = generator.minibatchAB(fn_B, self.batch_size) - - self.generator = generator - - - def train_one_step(self, iter_no, viewer): - when = self._clock() - _, warped_A, target_A = next(self.images_A) - _, warped_B, target_B = next(self.images_B) - - loss_A = self.model.autoencoder_A.train_on_batch(warped_A, target_A) - loss_B = self.model.autoencoder_B.train_on_batch(warped_B, target_B) - - self.model._epoch_no += 1 - - if isinstance(loss_A, (list, tuple)): - print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format( - time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock()-when, loss_A[1], loss_B[1]), - end='\r') - else: - print("[{0}] [#{1:05d}] [{2:.3f}s] loss_A: {3:.5f}, loss_B: {4:.5f}".format( - time.strftime("%H:%M:%S"), self.model._epoch_no, self._clock()-when, loss_A, loss_B), - end='\r') - - if viewer is not None: - viewer(self.show_sample(target_A[0:8], target_B[0:8]), "training using {}, bs={}".format(self.model, self.batch_size)) - - - def show_sample(self, test_A, test_B): - figure_A = numpy.stack([ - test_A, - self.model.autoencoder_A.predict(test_A), - self.model.autoencoder_B.predict(test_A), - ], axis=1) - - figure_B = numpy.stack([ - test_B, - self.model.autoencoder_B.predict(test_B), - self.model.autoencoder_A.predict(test_B), - ], axis=1) - - if (test_A.shape[0] % 2)!=0: - figure_A = numpy.concatenate ([figure_A, numpy.expand_dims(figure_A[0],0) ]) - figure_B = numpy.concatenate ([figure_B, numpy.expand_dims(figure_B[0],0) ]) - - figure = numpy.concatenate([figure_A, figure_B], axis=0) - w = 4 - h = int( figure.shape[0] / w) - figure = figure.reshape((w, h) + figure.shape[1:]) - figure = stack_images(figure) - - return numpy.clip(figure * 255, 0, 255).astype('uint8') - - - @property - def random_transform_args(self): - return self._random_transform_args diff --git a/plugins/model/Model_OriginalHighRes/__init__.py b/plugins/model/Model_OriginalHighRes/__init__.py deleted file mode 100644 index e6601b4a8b..0000000000 --- a/plugins/model/Model_OriginalHighRes/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- - -__author__ = """Based on https://reddit.com/u/deepfakes/""" - -from ._version import __version__ -from .Model import Model -from .Trainer import Trainer - diff --git a/plugins/model/Model_OriginalHighRes/_version.py b/plugins/model/Model_OriginalHighRes/_version.py deleted file mode 100644 index 96c4bb6956..0000000000 --- a/plugins/model/Model_OriginalHighRes/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = 0, 2, 7 \ No newline at end of file diff --git a/plugins/model/Model_OriginalHighRes/instance_normalization.py b/plugins/model/Model_OriginalHighRes/instance_normalization.py deleted file mode 100644 index 9c0616af27..0000000000 --- a/plugins/model/Model_OriginalHighRes/instance_normalization.py +++ /dev/null @@ -1,145 +0,0 @@ -from keras.engine import Layer, InputSpec -from keras import initializers, regularizers, constraints -from keras import backend as K -from keras.utils.generic_utils import get_custom_objects - -import numpy as np - - -class InstanceNormalization(Layer): - """Instance normalization layer (Lei Ba et al, 2016, Ulyanov et al., 2016). - Normalize the activations of the previous layer at each step, - i.e. applies a transformation that maintains the mean activation - close to 0 and the activation standard deviation close to 1. - # Arguments - axis: Integer, the axis that should be normalized - (typically the features axis). - For instance, after a `Conv2D` layer with - `data_format="channels_first"`, - set `axis=1` in `InstanceNormalization`. - Setting `axis=None` will normalize all values in each instance of the batch. - Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. - scale: If True, multiply by `gamma`. - If False, `gamma` is not used. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - beta_constraint: Optional constraint for the beta weight. - gamma_constraint: Optional constraint for the gamma weight. - # Input shape - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - # Output shape - Same shape as input. - # References - - [Layer Normalization](https://arxiv.org/abs/1607.06450) - - [Instance Normalization: The Missing Ingredient for Fast Stylization](https://arxiv.org/abs/1607.08022) - """ - def __init__(self, - axis=None, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer='zeros', - gamma_initializer='ones', - beta_regularizer=None, - gamma_regularizer=None, - beta_constraint=None, - gamma_constraint=None, - **kwargs): - super(InstanceNormalization, self).__init__(**kwargs) - self.supports_masking = True - self.axis = axis - self.epsilon = epsilon - self.center = center - self.scale = scale - self.beta_initializer = initializers.get(beta_initializer) - self.gamma_initializer = initializers.get(gamma_initializer) - self.beta_regularizer = regularizers.get(beta_regularizer) - self.gamma_regularizer = regularizers.get(gamma_regularizer) - self.beta_constraint = constraints.get(beta_constraint) - self.gamma_constraint = constraints.get(gamma_constraint) - - def build(self, input_shape): - ndim = len(input_shape) - if self.axis == 0: - raise ValueError('Axis cannot be zero') - - if (self.axis is not None) and (ndim == 2): - raise ValueError('Cannot specify axis for rank 1 tensor') - - self.input_spec = InputSpec(ndim=ndim) - - if self.axis is None: - shape = (1,) - else: - shape = (input_shape[self.axis],) - - if self.scale: - self.gamma = self.add_weight(shape=shape, - name='gamma', - initializer=self.gamma_initializer, - regularizer=self.gamma_regularizer, - constraint=self.gamma_constraint) - else: - self.gamma = None - if self.center: - self.beta = self.add_weight(shape=shape, - name='beta', - initializer=self.beta_initializer, - regularizer=self.beta_regularizer, - constraint=self.beta_constraint) - else: - self.beta = None - self.built = True - - def call(self, inputs, training=None): - input_shape = K.int_shape(inputs) - reduction_axes = list(range(0, len(input_shape))) - - if (self.axis is not None): - del reduction_axes[self.axis] - - del reduction_axes[0] - - mean = K.mean(inputs, reduction_axes, keepdims=True) - stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon - normed = (inputs - mean) / stddev - - broadcast_shape = [1] * len(input_shape) - if self.axis is not None: - broadcast_shape[self.axis] = input_shape[self.axis] - - if self.scale: - broadcast_gamma = K.reshape(self.gamma, broadcast_shape) - normed = normed * broadcast_gamma - if self.center: - broadcast_beta = K.reshape(self.beta, broadcast_shape) - normed = normed + broadcast_beta - return normed - - def get_config(self): - config = { - 'axis': self.axis, - 'epsilon': self.epsilon, - 'center': self.center, - 'scale': self.scale, - 'beta_initializer': initializers.serialize(self.beta_initializer), - 'gamma_initializer': initializers.serialize(self.gamma_initializer), - 'beta_regularizer': regularizers.serialize(self.beta_regularizer), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_constraint': constraints.serialize(self.beta_constraint), - 'gamma_constraint': constraints.serialize(self.gamma_constraint) - } - base_config = super(InstanceNormalization, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - -get_custom_objects().update({'InstanceNormalization': InstanceNormalization}) diff --git a/plugins/plugin_loader.py b/plugins/plugin_loader.py index 188d92275c..3506cb84cc 100644 --- a/plugins/plugin_loader.py +++ b/plugins/plugin_loader.py @@ -23,21 +23,22 @@ def get_aligner(name): @staticmethod def get_converter(name): """ Return requested converter plugin """ - return PluginLoader._import("Convert", "Convert_{0}".format(name)) + return PluginLoader._import("convert", name) @staticmethod def get_model(name): """ Return requested model plugin """ - return PluginLoader._import("Model", "Model_{0}".format(name)) + return PluginLoader._import("train.model", name) @staticmethod def get_trainer(name): """ Return requested trainer plugin """ - return PluginLoader._import("Trainer", "Model_{0}".format(name)) + return PluginLoader._import("train.trainer", name) @staticmethod def _import(attr, name): """ Import the plugin's module """ + name = name.replace("-", "_") ttl = attr.split(".")[-1].title() logger.info("Loading %s from %s plugin...", ttl, name.title()) attr = "model" if attr == "Trainer" else attr.lower() @@ -48,13 +49,23 @@ def _import(attr, name): @staticmethod def get_available_models(): """ Return a list of available models """ - models = () - modelpath = os.path.join(os.path.dirname(__file__), "model") - for modeldir in next(os.walk(modelpath))[1]: - if modeldir[0:6].lower() == 'model_': - models += (modeldir[6:],) + modelpath = os.path.join(os.path.dirname(__file__), "train", "model") + models = sorted(item.name.replace(".py", "").replace("_", "-") + for item in os.scandir(modelpath) + if not item.name.startswith("_") + and item.name.endswith(".py")) return models + @staticmethod + def get_available_converters(): + """ Return a list of available converters """ + converter_path = os.path.join(os.path.dirname(__file__), "convert") + converters = sorted(item.name.replace(".py", "").replace("_", "-") + for item in os.scandir(converter_path) + if not item.name.startswith("_") + and item.name.endswith(".py")) + return converters + @staticmethod def get_available_extractors(extractor_type): """ Return a list of available models """ @@ -72,4 +83,4 @@ def get_available_extractors(extractor_type): def get_default_model(): """ Return the default model """ models = PluginLoader.get_available_models() - return 'Original' if 'Original' in models else models[0] + return 'original' if 'original' in models else models[0] diff --git a/plugins/train/_config.py b/plugins/train/_config.py new file mode 100644 index 0000000000..994c9cd067 --- /dev/null +++ b/plugins/train/_config.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +""" Default configurations for models """ + +import logging + +from lib.config import FaceswapConfig + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +MASK_TYPES = ["none", "dfaker", "dfl_full"] +MASK_INFO = "The mask to be used for training. Select none to not use a mask" +COVERAGE_INFO = ("How much of the extracted image to train on. Generally the model is optimized\n" + "to the default value. Sensible values to use are:" + "\n\t62.5%% spans from eyebrow to eyebrow." + "\n\t75.0%% spans from temple to temple." + "\n\t87.5%% spans from ear to ear." + "\n\t100.0%% is a mugshot.") + + +class Config(FaceswapConfig): + """ Config File for Models """ + + def set_defaults(self): + """ Set the default values for config """ + logger.debug("Setting defaults") + # << GLOBAL OPTIONS >> # + section = "global" + self.add_section(title=section, + info="Options that apply to all models") + self.add_item( + section=section, title="icnr_init", datatype=bool, default=False, + info="Use ICNR Kernel Initializer for upscaling.\nThis can help reduce the " + "'checkerboard effect' when upscaling the image.") + self.add_item( + section=section, title="subpixel_upscaling", datatype=bool, default=False, + info="Use subpixel upscaling rather than pixel shuffler.\n" + "Might increase speed at cost of VRAM") + self.add_item( + section=section, title="reflect_padding", datatype=bool, default=False, + info="Use reflect padding rather than zero padding.") + self.add_item( + section=section, title="dssim_mask_loss", datatype=bool, default=True, + info="If using a mask, Use DSSIM loss for Mask training rather than Mean Absolute " + "Error\nMay increase overall quality.") + self.add_item( + section=section, title="penalized_mask_loss", datatype=bool, default=True, + info="If using a mask, Use Penalized loss for Mask training. Can stack with DSSIM.\n" + "May increase overall quality.") + + # << DFAKER OPTIONS >> # + section = "model.dfaker" + self.add_section(title=section, + info="Dfaker Model (Adapted from https://github.com/dfaker/df)") + self.add_item( + section=section, title="mask_type", datatype=str, default="dfaker", + choices=MASK_TYPES, info=MASK_INFO) + self.add_item( + section=section, title="coverage", datatype=float, default=100.0, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) + + # << DFL MODEL OPTIONS >> # + section = "model.dfl_h128" + self.add_section(title=section, + info="DFL H128 Model (Adapted from " + "https://github.com/iperov/DeepFaceLab)") + self.add_item( + section=section, title="lowmem", datatype=bool, default=False, + info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " + "with a changed lowmem mode are not compatible with each other.") + self.add_item( + section=section, title="mask_type", datatype=str, default="dfl_full", + choices=MASK_TYPES, info=MASK_INFO) + self.add_item( + section=section, title="coverage", datatype=float, default=62.5, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) + + # << IAE MODEL OPTIONS >> # + section = "model.iae" + self.add_section(title=section, + info="Intermediate Auto Encoder. Based on Original Model, uses " + "intermediate layers to try to better get details") + self.add_item( + section=section, title="dssim_loss", datatype=bool, default=False, + info="Use DSSIM for Loss rather than Mean Absolute Error\n" + "May increase overall quality.") + self.add_item( + section=section, title="mask_type", datatype=str, default="none", + choices=MASK_TYPES, info=MASK_INFO) + self.add_item( + section=section, title="coverage", datatype=float, default=62.5, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) + + # << ORIGINAL MODEL OPTIONS >> # + section = "model.original" + self.add_section(title=section, + info="Original Faceswap Model") + self.add_item( + section=section, title="lowmem", datatype=bool, default=False, + info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " + "with a changed lowmem mode are not compatible with each other.") + self.add_item( + section=section, title="dssim_loss", datatype=bool, default=False, + info="Use DSSIM for Loss rather than Mean Absolute Error\n" + "May increase overall quality.") + self.add_item( + section=section, title="mask_type", datatype=str, default="none", + choices=MASK_TYPES, info=MASK_INFO) + self.add_item( + section=section, title="coverage", datatype=float, default=62.5, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) + + # << UNBALANCED MODEL OPTIONS >> # + section = "model.unbalanced" + self.add_section(title=section, + info="An unbalanced model with adjustable input size options.\n" + "This is an unbalanced model so b>a swaps may not work well") + self.add_item( + section=section, title="lowmem", datatype=bool, default=False, + info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " + "with a changed lowmem mode are not compatible with each other. NB: lowmem will " + "override cutom nodes and complexity settings.") + self.add_item( + section=section, title="dssim_loss", datatype=bool, default=False, + info="Use DSSIM for Loss rather than Mean Absolute Error\n" + "May increase overall quality.") + self.add_item( + section=section, title="mask_type", datatype=str, default="none", + choices=MASK_TYPES, info=MASK_INFO) + self.add_item( + section=section, title="nodes", datatype=int, default=1024, rounding=64, + min_max=(512, 4096), + info="Number of nodes for decoder. Don't change this unless you " + "know what you are doing!") + self.add_item( + section=section, title="complexity_encoder", datatype=int, default=128, + rounding=16, min_max=(64, 1024), + info="Encoder Convolution Layer Complexity. sensible ranges: " + "128 to 160") + self.add_item( + section=section, title="complexity_decoder_a", datatype=int, default=384, + rounding=16, min_max=(64, 1024), + info="Decoder A Complexity.") + self.add_item( + section=section, title="complexity_decoder_b", datatype=int, default=512, + rounding=16, min_max=(64, 1024), + info="Decoder B Complexity.") + self.add_item( + section=section, title="input_size", datatype=int, default=128, + rounding=64, min_max=(64, 512), + info="Resolution (in pixels) of the image to train on.\n" + "BE AWARE Larger resolution will dramatically increase" + "VRAM requirements.\n" + "Make sure your resolution is divisible by 64 (e.g. 64, 128, 256 etc.).\n" + "NB: Your faceset must be at least 1.6x larger than your required input size.\n" + " (e.g. 160 is the maximum input size for a 256x256 faceset)") + self.add_item( + section=section, title="coverage", datatype=float, default=62.5, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) + + # << VILLAIN MODEL OPTIONS >> # + section = "model.villain" + self.add_section(title=section, + info="A Higher resolution version of the Original " + "Model by VillainGuy.\n" + "Extremely VRAM heavy. Full model requires 9GB+ for batchsize 16") + self.add_item( + section=section, title="lowmem", datatype=bool, default=False, + info="Lower memory mode. Set to 'True' if having issues with VRAM useage.\nNB: Models " + "with a changed lowmem mode are not compatible with each other.") + self.add_item( + section=section, title="dssim_loss", datatype=bool, default=False, + info="Use DSSIM for Loss rather than Mean Absolute Error\n" + "May increase overall quality.") + self.add_item( + section=section, title="mask_type", datatype=str, default="none", + choices=["none", "dfaker", "dfl_full"], + info="The mask to be used for training. Select none to not use a mask") + self.add_item( + section=section, title="coverage", datatype=float, default=62.5, rounding=1, + min_max=(62.5, 100.0), info=COVERAGE_INFO) diff --git a/plugins/train/model/__init__.py b/plugins/train/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/train/model/_base.py b/plugins/train/model/_base.py new file mode 100644 index 0000000000..07ae0f9248 --- /dev/null +++ b/plugins/train/model/_base.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 +""" Base class for Models. ALL Models should at least inherit from this class + + When inheriting model_data should be a list of NNMeta objects. + See the class for details. +""" +import logging +import os +import sys +import time + +from json import JSONDecodeError + +from keras import losses +from keras.models import load_model +from keras.optimizers import Adam +from keras.utils import get_custom_objects, multi_gpu_model + +from lib import Serializer +from lib.model.losses import DSSIMObjective, PenalizedLoss +from lib.model.nn_blocks import NNBlocks +from lib.multithreading import MultiThread +from plugins.train._config import Config + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name +_CONFIG = None + + +class ModelBase(): + """ Base class that all models should inherit from """ + def __init__(self, + model_dir, + gpus, + no_logs=False, + warp_to_landmarks=False, + no_flip=False, + training_image_size=256, + alignments_paths=None, + preview_scale=100, + input_shape=None, + encoder_dim=None, + trainer="original", + predict=False): + logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, " + "training_image_size, %s, alignments_paths: %s, preview_scale: %s, " + "input_shape: %s, encoder_dim: %s)", self.__class__.__name__, model_dir, gpus, + training_image_size, alignments_paths, preview_scale, input_shape, + encoder_dim) + self.predict = predict + self.model_dir = model_dir + self.gpus = gpus + self.blocks = NNBlocks(use_subpixel=self.config["subpixel_upscaling"], + use_icnr_init=self.config["icnr_init"], + use_reflect_padding=self.config["reflect_padding"]) + self.input_shape = input_shape + self.output_shape = None # set after model is compiled + self.encoder_dim = encoder_dim + self.trainer = trainer + + self.state = State(self.model_dir, self.name, no_logs, training_image_size) + self.load_state_info() + + self.networks = dict() # Networks for the model + self.predictors = dict() # Predictors for model + self.history = dict() # Loss history per save iteration) + + # Training information specific to the model should be placed in this + # dict for reference by the trainer. + self.training_opts = {"alignments": alignments_paths, + "preview_scaling": preview_scale / 100, + "warp_to_landmarks": warp_to_landmarks, + "no_flip": no_flip} + + self.build() + self.set_training_data() + logger.debug("Initialized ModelBase (%s)", self.__class__.__name__) + + @property + def config(self): + """ Return config dict for current plugin """ + global _CONFIG # pylint: disable=global-statement + if not _CONFIG: + model_name = ".".join(self.__module__.split(".")[-2:]) + logger.debug("Loading config for: %s", model_name) + _CONFIG = Config(model_name).config_dict + return _CONFIG + + @property + def name(self): + """ Set the model name based on the subclass """ + basename = os.path.basename(sys.modules[self.__module__].__file__) + retval = os.path.splitext(basename)[0].lower() + logger.debug("model name: '%s'", retval) + return retval + + def set_training_data(self): + """ Override to set model specific training data. + + super() this method for defaults otherwise be sure to add """ + logger.debug("Setting training data") + self.training_opts["training_size"] = self.state.training_size + self.training_opts["no_logs"] = self.state.current_session["no_logs"] + self.training_opts["mask_type"] = self.config.get("mask_type", None) + self.training_opts["coverage_ratio"] = self.config.get("coverage", 62.5) / 100 + self.training_opts["preview_images"] = 14 + logger.debug("Set training data: %s", self.training_opts) + + def build(self): + """ Build the model. Override for custom build methods """ + self.add_networks() + self.load_models(swapped=False) + self.build_autoencoders() + self.log_summary() + self.compile_predictors() + + def build_autoencoders(self): + """ Override for Model Specific autoencoder builds + + NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names + are expected: + face (the input for image) + mask (the input for mask if it is used) + + """ + raise NotImplementedError + + def add_networks(self): + """ Override to add neural networks """ + raise NotImplementedError + + def load_state_info(self): + """ Load the input shape from state file if it exists """ + logger.debug("Loading Input Shape from State file") + if not self.state.inputs: + logger.debug("No input shapes saved. Using model config") + return + if not self.state.face_shapes: + logger.warning("Input shapes stored in State file, but no matches for 'face'." + "Using model config") + return + input_shape = self.state.face_shapes[0] + logger.debug("Setting input shape from state file: %s", input_shape) + self.input_shape = input_shape + + def add_network(self, network_type, side, network): + """ Add a NNMeta object """ + logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network) + filename = "{}_{}".format(self.name, network_type.lower()) + name = network_type.lower() + if side: + side = side.lower() + filename += "_{}".format(side.upper()) + name += "_{}".format(side) + filename += ".h5" + logger.debug("name: '%s', filename: '%s'", name, filename) + self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network) + + def add_predictor(self, side, model): + """ Add a predictor to the predictors dictionary """ + logger.debug("Adding predictor: (side: '%s', model: %s)", side, model) + if self.gpus > 1: + logger.debug("Converting to multi-gpu: side %s", side) + model = multi_gpu_model(model, self.gpus) + self.predictors[side] = model + if not self.state.inputs: + self.store_input_shapes(model) + if not self.output_shape: + self.set_output_shape(model) + + def store_input_shapes(self, model): + """ Store the input and output shapes to state """ + logger.debug("Adding input shapes to state for model") + inputs = {tensor.name: tensor.get_shape().as_list()[-3:] for tensor in model.inputs} + if not any(inp for inp in inputs.keys() if inp.startswith("face")): + raise ValueError("No input named 'face' was found. Check your input naming. " + "Current input names: {}".format(inputs)) + self.state.inputs = inputs + logger.debug("Added input shapes: %s", self.state.inputs) + + def set_output_shape(self, model): + """ Set the output shape for use in training and convert """ + logger.debug("Setting output shape") + out = [tensor.get_shape().as_list()[-3:] for tensor in model.outputs] + if not out: + raise ValueError("No outputs found! Check your model.") + self.output_shape = tuple(out[0]) + logger.debug("Added output shape: %s", self.output_shape) + + def compile_predictors(self): + """ Compile the predictors """ + logger.debug("Compiling Predictors") + optimizer = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0) + + for side, model in self.predictors.items(): + loss_names = ["loss"] + loss_funcs = [self.loss_function(side)] + mask = [inp for inp in model.inputs if inp.name.startswith("mask")] + if mask: + loss_names.insert(0, "mask_loss") + loss_funcs.insert(0, self.mask_loss_function(mask[0], side)) + model.compile(optimizer=optimizer, loss=loss_funcs) + + if len(loss_names) > 1: + loss_names.insert(0, "total_loss") + self.state.add_session_loss_names(side, loss_names) + self.history[side] = list() + logger.debug("Compiled Predictors. Losses: %s", loss_names) + + def loss_function(self, side): + """ Set the loss function """ + if self.config.get("dssim_loss", False): + if side == "a" and not self.predict: + logger.verbose("Using DSSIM Loss") + loss_func = DSSIMObjective() + else: + if side == "a" and not self.predict: + logger.verbose("Using Mean Absolute Error Loss") + loss_func = losses.mean_absolute_error + logger.debug(loss_func) + return loss_func + + def mask_loss_function(self, mask, side): + """ Set the loss function for masks + Side is input so we only log once """ + if self.config.get("dssim_mask_loss", False): + if side == "a" and not self.predict: + logger.verbose("Using DSSIM Loss for mask") + mask_loss_func = DSSIMObjective() + else: + if side == "a" and not self.predict: + logger.verbose("Using Mean Absolute Error Loss for mask") + mask_loss_func = losses.mean_absolute_error + + if self.config.get("penalized_mask_loss", False): + if side == "a" and not self.predict: + logger.verbose("Using Penalized Loss for mask") + mask_loss_func = PenalizedLoss(mask, mask_loss_func) + logger.debug(mask_loss_func) + return mask_loss_func + + def converter(self, swap): + """ Converter for autoencoder models """ + logger.debug("Getting Converter: (swap: %s)", swap) + if swap: + retval = self.predictors["a"].predict + else: + retval = self.predictors["b"].predict + logger.debug("Got Converter: %s", retval) + return retval + + @property + def iterations(self): + "Get current training iteration number" + return self.state.iterations + + def map_models(self, swapped): + """ Map the models for A/B side for swapping """ + logger.debug("Map models: (swapped: %s)", swapped) + models_map = {"a": dict(), "b": dict()} + sides = ("a", "b") if not swapped else ("b", "a") + for network in self.networks.values(): + if network.side == sides[0]: + models_map["a"][network.type] = network.filename + if network.side == sides[1]: + models_map["b"][network.type] = network.filename + logger.debug("Mapped models: (models_map: %s)", models_map) + return models_map + + def log_summary(self): + """ Verbose log the model summaries """ + if self.predict: + return + for side in sorted(list(self.predictors.keys())): + logger.verbose("[%s %s Summary]:", self.name.title(), side.upper()) + self.predictors[side].summary(print_fn=lambda x: logger.verbose("R|%s", x)) + for name, nnmeta in self.networks.items(): + if nnmeta.side is not None and nnmeta.side != side: + continue + logger.verbose("%s:", name.title()) + nnmeta.network.summary(print_fn=lambda x: logger.verbose("R|%s", x)) + + def load_models(self, swapped): + """ Load models from file """ + logger.debug("Load model: (swapped: %s)", swapped) + model_mapping = self.map_models(swapped) + for network in self.networks.values(): + if not network.side: + is_loaded = network.load(predict=self.predict) + else: + is_loaded = network.load(fullpath=model_mapping[network.side][network.type], + predict=self.predict) + if not is_loaded: + break + if is_loaded: + logger.info("Loaded model from disk: '%s'", self.model_dir) + return is_loaded + + def save_models(self): + """ Backup and save the models """ + logger.debug("Backing up and saving models") + should_backup = self.get_save_averages() + save_threads = list() + for network in self.networks.values(): + name = "save_{}".format(network.name) + save_threads.append(MultiThread(network.save, name=name, should_backup=should_backup)) + save_threads.append(MultiThread(self.state.save, + name="save_state", should_backup=should_backup)) + for thread in save_threads: + thread.start() + for thread in save_threads: + if thread.has_error: + logger.error(thread.errors[0]) + thread.join() + # Put in a line break to avoid jumbled console + print("\n") + logger.info("saved models") + + def get_save_averages(self): + """ Return the loss averages since last save and reset historical losses + + This protects against model corruption by only backing up the model + if any of the loss values have fallen. + TODO This is not a perfect system. If the model corrupts on save_iteration - 1 + then model may still backup + """ + logger.debug("Getting Average loss since last save") + avgs = dict() + backup = True + + for side, loss in self.history.items(): + if not loss: + backup = False + break + + avgs[side] = sum(loss) / len(loss) + self.history[side] = list() # Reset historical loss + + if not self.state.lowest_avg_loss.get(side, None): + logger.debug("Setting initial save iteration loss average for '%s': %s", + side, avgs[side]) + self.state.lowest_avg_loss[side] = avgs[side] + continue + + if backup: + # Only run this if backup is true. All losses must have dropped for a valid backup + backup = self.check_loss_drop(side, avgs[side]) + + logger.debug("Lowest historical save iteration loss average: %s", + self.state.lowest_avg_loss) + logger.debug("Average loss since last save: %s", avgs) + + if backup: # Update lowest loss values to the state + for side, avg_loss in avgs.items(): + logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss) + self.state.lowest_avg_loss[side] = avg_loss + + logger.debug("Backing up: %s", backup) + return backup + + def check_loss_drop(self, side, avg): + """ Check whether total loss has dropped since lowest loss """ + if avg < self.state.lowest_avg_loss[side]: + logger.debug("Loss for '%s' has dropped", side) + return True + logger.debug("Loss for '%s' has not dropped", side) + return False + + +class NNMeta(): + """ Class to hold a neural network and it's meta data + + filename: The full path and filename of the model file for this network. + type: The type of network. For networks that can be swapped + The type should be identical for the corresponding + A and B networks, and should be unique for every A/B pair. + Otherwise the type should be completely unique. + side: A, B or None. Used to identify which networks can + be swapped. + network: Define network to this. + """ + + def __init__(self, filename, network_type, side, network): + logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', " + "network: %s", self.__class__.__name__, filename, network_type, + side, network) + self.filename = filename + self.type = network_type.lower() + self.side = side + self.name = self.set_name() + self.network = network + self.network.name = self.name + logger.debug("Initialized %s", self.__class__.__name__) + + def set_name(self): + """ Set the network name """ + name = self.type + if self.side: + name += "_{}".format(self.side) + return name + + def load(self, fullpath=None, predict=False): + """ Load model """ + fullpath = fullpath if fullpath else self.filename + logger.debug("Loading model: '%s'", fullpath) + try: + network = load_model(self.filename, custom_objects=get_custom_objects()) + except ValueError as err: + if str(err).lower().startswith("cannot create group in read only mode"): + self.convert_legacy_weights() + return True + if predict: + raise ValueError("Unable to load training data. Error: {}".format(str(err))) + logger.warning("Failed loading existing training data. Generating new models") + logger.debug("Exception: %s", str(err)) + return False + except OSError as err: # pylint: disable=broad-except + if predict: + raise ValueError("Unable to load training data. Error: {}".format(str(err))) + logger.warning("Failed loading existing training data. Generating new models") + logger.debug("Exception: %s", str(err)) + return False + self.network = network # Update network with saved model + self.network.name = self.type + return True + + def save(self, fullpath=None, should_backup=False): + """ Save model """ + fullpath = fullpath if fullpath else self.filename + if should_backup: + self.backup(fullpath=fullpath) + logger.debug("Saving model: '%s'", fullpath) + self.network.save(fullpath) + + def backup(self, fullpath=None): + """ Backup Model """ + origfile = fullpath if fullpath else self.filename + backupfile = origfile + ".bk" + logger.debug("Backing up: '%s' to '%s'", origfile, backupfile) + if os.path.exists(backupfile): + os.remove(backupfile) + if os.path.exists(origfile): + os.rename(origfile, backupfile) + + def convert_legacy_weights(self): + """ Convert legacy weights files to hold the model topology """ + logger.info("Adding model topology to legacy weights file: '%s'", self.filename) + self.network.load_weights(self.filename) + self.save(should_backup=False) + self.network.name = self.type + + +class State(): + """ Class to hold the model's current state and autoencoder structure """ + def __init__(self, model_dir, model_name, no_logs, training_image_size): + logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', no_logs: %s, " + "training_image_size: '%s'", self.__class__.__name__, model_dir, + model_name, no_logs, training_image_size) + self.serializer = Serializer.get_serializer("json") + filename = "{}_state.{}".format(model_name, self.serializer.ext) + self.filename = str(model_dir / filename) + self.iterations = 0 + self.session_iterations = 0 + self.training_size = training_image_size + self.sessions = dict() + self.lowest_avg_loss = dict() + self.inputs = dict() + self.config = dict() + self.load() + self.session_id = self.new_session_id() + self.create_new_session(no_logs) + logger.debug("Initialized %s:", self.__class__.__name__) + + @property + def face_shapes(self): + """ Return a list of stored face shape inputs """ + return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")] + + @property + def mask_shapes(self): + """ Return a list of stored mask shape inputs """ + return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")] + + @property + def loss_names(self): + """ Return the loss names for this session """ + return self.sessions[self.session_id]["loss_names"] + + @property + def current_session(self): + """ Return the current session dict """ + return self.sessions[self.session_id] + + def new_session_id(self): + """ Return new session_id """ + if not self.sessions: + session_id = 1 + else: + session_id = max(int(key) for key in self.sessions.keys()) + 1 + logger.debug(session_id) + return session_id + + def create_new_session(self, no_logs): + """ Create a new session """ + logger.debug("Creating new session. id: %s", self.session_id) + self.sessions[self.session_id] = {"timestamp": time.time(), + "no_logs": no_logs, + "loss_names": dict(), + "batchsize": 0, + "iterations": 0} + + def add_session_loss_names(self, side, loss_names): + """ Add the session loss names to the sessions dictionary """ + logger.debug("Adding session loss_names. (side: '%s', loss_names: %s", side, loss_names) + self.sessions[self.session_id]["loss_names"][side] = loss_names + + def add_session_batchsize(self, batchsize): + """ Add the session batchsize to the sessions dictionary """ + logger.debug("Adding session batchsize: %s", batchsize) + self.sessions[self.session_id]["batchsize"] = batchsize + + def increment_iterations(self): + """ Increment total and session iterations """ + self.iterations += 1 + self.sessions[self.session_id]["iterations"] += 1 + + def load(self): + """ Load state file """ + logger.debug("Loading State") + try: + with open(self.filename, "rb") as inp: + state = self.serializer.unmarshal(inp.read().decode("utf-8")) + self.sessions = state.get("sessions", dict()) + self.lowest_avg_loss = state.get("lowest_avg_loss", dict()) + self.iterations = state.get("iterations", 0) + self.training_size = state.get("training_size", 256) + self.inputs = state.get("inputs", dict()) + self.config = state.get("config", dict()) + logger.debug("Loaded state: %s", state) + self.replace_config() + except IOError as err: + logger.warning("No existing state file found. Generating.") + logger.debug("IOError: %s", str(err)) + except JSONDecodeError as err: + logger.debug("JSONDecodeError: %s:", str(err)) + + def save(self, should_backup=False): + """ Save iteration number to state file """ + logger.debug("Saving State") + if should_backup: + self.backup() + try: + with open(self.filename, "wb") as out: + state = {"sessions": self.sessions, + "lowest_avg_loss": self.lowest_avg_loss, + "iterations": self.iterations, + "inputs": self.inputs, + "training_size": self.training_size, + "config": _CONFIG} + state_json = self.serializer.marshal(state) + out.write(state_json.encode("utf-8")) + except IOError as err: + logger.error("Unable to save model state: %s", str(err.strerror)) + logger.debug("Saved State") + + def backup(self): + """ Backup state file """ + origfile = self.filename + backupfile = origfile + ".bk" + logger.debug("Backing up: '%s' to '%s'", origfile, backupfile) + if os.path.exists(backupfile): + os.remove(backupfile) + if os.path.exists(origfile): + os.rename(origfile, backupfile) + + def replace_config(self): + """ Replace the loaded config with the one contained within the state file """ + global _CONFIG # pylint: disable=global-statement + # Add any new items to state config for legacy purposes + for key, val in _CONFIG.items(): + if key not in self.config.keys(): + logger.info("Adding new config item to state file: '%s': '%s'", key, val) + self.config[key] = val + logger.debug("Replacing config. Old config: %s", _CONFIG) + _CONFIG = self.config + logger.debug("Replaced config. New config: %s", _CONFIG) + logger.info("Using configuration saved in state file") diff --git a/plugins/train/model/dfaker.py b/plugins/train/model/dfaker.py new file mode 100644 index 0000000000..d097e20ac4 --- /dev/null +++ b/plugins/train/model/dfaker.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +""" DFaker Model + Based on the dfaker model: https://github.com/dfaker """ + + +from keras.initializers import RandomNormal +from keras.layers import Conv2D, Input +from keras.models import Model as KerasModel + +from .original import logger, Model as OriginalModel + + +class Model(OriginalModel): + """ Improved Autoeencoder Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + kwargs["input_shape"] = (64, 64, 3) + kwargs["encoder_dim"] = 1024 + self.kernel_initializer = RandomNormal(0, 0.02) + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def build_autoencoders(self): + """ Initialize Dfaker model """ + logger.debug("Initializing model") + inputs = [Input(shape=self.input_shape, name="face")] + if self.config.get("mask_type", None): + mask_shape = (self.input_shape[0] * 2, self.input_shape[1] * 2, 1) + inputs.append(Input(shape=mask_shape, name="mask")) + + for side in ("a", "b"): + decoder = self.networks["decoder_{}".format(side)].network + output = decoder(self.networks["encoder"].network(inputs[0])) + autoencoder = KerasModel(inputs, output) + self.add_predictor(side, autoencoder) + logger.debug("Initialized model") + + def decoder(self): + """ Decoder Network """ + input_ = Input(shape=(8, 8, 512)) + var_x = input_ + + var_x = self.blocks.upscale(var_x, 512, res_block_follows=True) + var_x = self.blocks.res_block(var_x, 512, kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, 256, res_block_follows=True) + var_x = self.blocks.res_block(var_x, 256, kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, 128, res_block_follows=True) + var_x = self.blocks.res_block(var_x, 128, kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, 64) + var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = input_ + var_y = self.blocks.upscale(var_y, 512) + var_y = self.blocks.upscale(var_y, 256) + var_y = self.blocks.upscale(var_y, 128) + var_y = self.blocks.upscale(var_y, 64) + var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + outputs.append(var_y) + return KerasModel([input_], outputs=outputs) diff --git a/plugins/train/model/dfl_h128.py b/plugins/train/model/dfl_h128.py new file mode 100644 index 0000000000..7e780c78cc --- /dev/null +++ b/plugins/train/model/dfl_h128.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +""" DeepFakesLab H128 Model + Based on https://github.com/iperov/DeepFaceLab +""" + +from keras.layers import Conv2D, Dense, Flatten, Input, Reshape +from keras.models import Model as KerasModel + +from .original import logger, Model as OriginalModel + + +class Model(OriginalModel): + """ Low Memory version of Original Faceswap Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + + kwargs["input_shape"] = (128, 128, 3) + kwargs["encoder_dim"] = 256 if self.config["lowmem"] else 512 + + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def encoder(self): + """ DFL H128 Encoder """ + input_ = Input(shape=self.input_shape) + var_x = input_ + var_x = self.blocks.conv(var_x, 128) + var_x = self.blocks.conv(var_x, 256) + var_x = self.blocks.conv(var_x, 512) + var_x = self.blocks.conv(var_x, 1024) + var_x = Dense(self.encoder_dim)(Flatten()(var_x)) + var_x = Dense(8 * 8 * self.encoder_dim)(var_x) + var_x = Reshape((8, 8, self.encoder_dim))(var_x) + var_x = self.blocks.upscale(var_x, self.encoder_dim) + return KerasModel(input_, var_x) + + def decoder(self): + """ DFL H128 Decoder """ + input_ = Input(shape=(16, 16, self.encoder_dim)) + var = input_ + var = self.blocks.upscale(var, self.encoder_dim) + var = self.blocks.upscale(var, self.encoder_dim // 2) + var = self.blocks.upscale(var, self.encoder_dim // 4) + + # Face + var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var) + outputs = [var_x] + # Mask + if self.config.get("mask_type", None): + var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) diff --git a/plugins/train/model/iae.py b/plugins/train/model/iae.py new file mode 100644 index 0000000000..2ab622ce35 --- /dev/null +++ b/plugins/train/model/iae.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" Improved autoencoder for faceswap """ + +from keras.layers import Concatenate, Conv2D, Dense, Flatten, Input, Reshape +from keras.models import Model as KerasModel + +from ._base import ModelBase, logger + + +class Model(ModelBase): + """ Improved Autoeencoder Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + kwargs["input_shape"] = (64, 64, 3) + kwargs["encoder_dim"] = 1024 + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def add_networks(self): + """ Add the IAE model weights """ + logger.debug("Adding networks") + self.add_network("encoder", None, self.encoder()) + self.add_network("decoder", None, self.decoder()) + self.add_network("intermediate", "a", self.intermediate()) + self.add_network("intermediate", "b", self.intermediate()) + self.add_network("inter", None, self.intermediate()) + logger.debug("Added networks") + + def build_autoencoders(self): + """ Initialize IAE model """ + logger.debug("Initializing model") + inputs = [Input(shape=self.input_shape, name="face")] + if self.config.get("mask_type", "none") != "none": + mask_shape = (self.input_shape[:2] + (1, )) + inputs.append(Input(shape=mask_shape, name="mask")) + + decoder = self.networks["decoder"].network + encoder = self.networks["encoder"].network + inter_both = self.networks["inter"].network + for side in ("a", "b"): + inter_side = self.networks["intermediate_{}".format(side)].network + output = decoder(Concatenate()([inter_side(encoder(inputs[0])), + inter_both(encoder(inputs[0]))])) + + autoencoder = KerasModel(inputs, output) + self.add_predictor(side, autoencoder) + logger.debug("Initialized model") + + def encoder(self): + """ Encoder Network """ + input_ = Input(shape=self.input_shape) + var_x = input_ + var_x = self.blocks.conv(var_x, 128) + var_x = self.blocks.conv(var_x, 266) + var_x = self.blocks.conv(var_x, 512) + var_x = self.blocks.conv(var_x, 1024) + var_x = Flatten()(var_x) + return KerasModel(input_, var_x) + + def intermediate(self): + """ Intermediate Network """ + input_ = Input(shape=(None, 4 * 4 * 1024)) + var_x = input_ + var_x = Dense(self.encoder_dim)(var_x) + var_x = Dense(4 * 4 * int(self.encoder_dim/2))(var_x) + var_x = Reshape((4, 4, int(self.encoder_dim/2)))(var_x) + return KerasModel(input_, var_x) + + def decoder(self): + """ Decoder Network """ + input_ = Input(shape=(4, 4, self.encoder_dim)) + var_x = input_ + var_x = self.blocks.upscale(var_x, 512) + var_x = self.blocks.upscale(var_x, 256) + var_x = self.blocks.upscale(var_x, 128) + var_x = self.blocks.upscale(var_x, 64) + var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = Conv2D(1, kernel_size=5, padding="same", activation="sigmoid")(var_x) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) diff --git a/plugins/train/model/original.py b/plugins/train/model/original.py new file mode 100644 index 0000000000..e065aa00c8 --- /dev/null +++ b/plugins/train/model/original.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" Original Model + Based on the original https://www.reddit.com/r/deepfakes/ + code sample + contribs """ + +from keras.layers import Conv2D, Dense, Flatten, Input, Reshape + +from keras.models import Model as KerasModel + +from ._base import ModelBase, logger + + +class Model(ModelBase): + """ Original Faceswap Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + + if "input_shape" not in kwargs: + kwargs["input_shape"] = (64, 64, 3) + if "encoder_dim" not in kwargs: + kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024 + + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def add_networks(self): + """ Add the original model weights """ + logger.debug("Adding networks") + self.add_network("decoder", "a", self.decoder()) + self.add_network("decoder", "b", self.decoder()) + self.add_network("encoder", None, self.encoder()) + logger.debug("Added networks") + + def build_autoencoders(self): + """ Initialize original model """ + logger.debug("Initializing model") + inputs = [Input(shape=self.input_shape, name="face")] + if self.config.get("mask_type", None): + mask_shape = (self.input_shape[:2] + (1, )) + inputs.append(Input(shape=mask_shape, name="mask")) + + for side in ("a", "b"): + logger.debug("Adding Autoencoder. Side: %s", side) + decoder = self.networks["decoder_{}".format(side)].network + output = decoder(self.networks["encoder"].network(inputs[0])) + autoencoder = KerasModel(inputs, output) + self.add_predictor(side, autoencoder) + logger.debug("Initialized model") + + def encoder(self): + """ Encoder Network """ + input_ = Input(shape=self.input_shape) + var_x = input_ + var_x = self.blocks.conv(var_x, 128) + var_x = self.blocks.conv(var_x, 256) + var_x = self.blocks.conv(var_x, 512) + if not self.config.get("lowmem", False): + var_x = self.blocks.conv(var_x, 1024) + var_x = Dense(self.encoder_dim)(Flatten()(var_x)) + var_x = Dense(4 * 4 * 1024)(var_x) + var_x = Reshape((4, 4, 1024))(var_x) + var_x = self.blocks.upscale(var_x, 512) + return KerasModel(input_, var_x) + + def decoder(self): + """ Decoder Network """ + input_ = Input(shape=(8, 8, 512)) + var_x = input_ + var_x = self.blocks.upscale(var_x, 256) + var_x = self.blocks.upscale(var_x, 128) + var_x = self.blocks.upscale(var_x, 64) + var_x = Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = input_ + var_y = self.blocks.upscale(var_y, 256) + var_y = self.blocks.upscale(var_y, 128) + var_y = self.blocks.upscale(var_y, 64) + var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) diff --git a/plugins/train/model/unbalanced.py b/plugins/train/model/unbalanced.py new file mode 100644 index 0000000000..85ede86fc9 --- /dev/null +++ b/plugins/train/model/unbalanced.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +""" Unbalanced Model + Based on the original https://www.reddit.com/r/deepfakes/ + code sample + contribs """ + +from keras.initializers import RandomNormal +from keras.layers import Conv2D, Dense, Flatten, Input, Reshape, SpatialDropout2D +from keras.models import Model as KerasModel + +from .original import logger, Model as OriginalModel + + +class Model(OriginalModel): + """ Unbalanced Faceswap Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + + self.lowmem = self.config.get("lowmem", False) + kwargs["input_shape"] = (self.config["input_size"], self.config["input_size"], 3) + kwargs["encoder_dim"] = 512 if self.lowmem else self.config["nodes"] + self.kernel_initializer = RandomNormal(0, 0.02) + + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def add_networks(self): + """ Add the original model weights """ + logger.debug("Adding networks") + self.add_network("decoder", "a", self.decoder_a()) + self.add_network("decoder", "b", self.decoder_b()) + self.add_network("encoder", None, self.encoder()) + logger.debug("Added networks") + + def encoder(self): + """ Unbalanced Encoder """ + kwargs = dict(kernel_initializer=self.kernel_initializer) + encoder_complexity = 128 if self.lowmem else self.config["complexity_encoder"] + dense_dim = 384 if self.lowmem else 512 + dense_shape = self.input_shape[0] // 16 + input_ = Input(shape=self.input_shape) + + var_x = input_ + var_x = self.blocks.conv(var_x, encoder_complexity, use_instance_norm=True, **kwargs) + var_x = self.blocks.conv(var_x, encoder_complexity * 2, use_instance_norm=True, **kwargs) + var_x = self.blocks.conv(var_x, encoder_complexity * 4, **kwargs) + var_x = self.blocks.conv(var_x, encoder_complexity * 6, **kwargs) + var_x = self.blocks.conv(var_x, encoder_complexity * 8, **kwargs) + var_x = Dense(self.encoder_dim, + kernel_initializer=self.kernel_initializer)(Flatten()(var_x)) + var_x = Dense(dense_shape * dense_shape * dense_dim, + kernel_initializer=self.kernel_initializer)(var_x) + var_x = Reshape((dense_shape, dense_shape, dense_dim))(var_x) + return KerasModel(input_, var_x) + + def decoder_a(self): + """ Decoder for side A """ + kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer) + decoder_complexity = 320 if self.lowmem else self.config["complexity_decoder_a"] + dense_dim = 384 if self.lowmem else 512 + decoder_shape = self.input_shape[0] // 16 + input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim)) + + var_x = input_ + + var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs) + var_x = SpatialDropout2D(0.25)(var_x) + var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs) + if self.lowmem: + var_x = SpatialDropout2D(0.15)(var_x) + else: + var_x = SpatialDropout2D(0.25)(var_x) + var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs) + var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs) + var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = input_ + var_y = self.blocks.upscale(var_y, decoder_complexity) + var_y = self.blocks.upscale(var_y, decoder_complexity) + var_y = self.blocks.upscale(var_y, decoder_complexity // 2) + var_y = self.blocks.upscale(var_y, decoder_complexity // 4) + var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) + + def decoder_b(self): + """ Decoder for side B """ + kwargs = dict(kernel_size=5, kernel_initializer=self.kernel_initializer) + dense_dim = 384 if self.lowmem else self.config["complexity_decoder_b"] + decoder_complexity = 384 if self.lowmem else 512 + decoder_shape = self.input_shape[0] // 16 + input_ = Input(shape=(decoder_shape, decoder_shape, dense_dim)) + + var_x = input_ + if self.lowmem: + var_x = self.blocks.upscale(var_x, decoder_complexity, **kwargs) + var_x = self.blocks.upscale(var_x, decoder_complexity // 2, **kwargs) + var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs) + var_x = self.blocks.upscale(var_x, decoder_complexity // 8, **kwargs) + else: + var_x = self.blocks.upscale(var_x, decoder_complexity, + res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, decoder_complexity, + kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, decoder_complexity, + res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, decoder_complexity, + kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, decoder_complexity // 2, + res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, decoder_complexity // 2, + kernel_initializer=self.kernel_initializer) + var_x = self.blocks.upscale(var_x, decoder_complexity // 4, **kwargs) + var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = input_ + var_y = self.blocks.upscale(var_y, decoder_complexity) + if not self.lowmem: + var_y = self.blocks.upscale(var_y, decoder_complexity) + var_y = self.blocks.upscale(var_y, decoder_complexity // 2) + var_y = self.blocks.upscale(var_y, decoder_complexity // 4) + if self.lowmem: + var_y = self.blocks.upscale(var_y, decoder_complexity // 8) + var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) diff --git a/plugins/train/model/villain.py b/plugins/train/model/villain.py new file mode 100644 index 0000000000..c4d18ac1cd --- /dev/null +++ b/plugins/train/model/villain.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" Original - VillainGuy model + Based on the original https://www.reddit.com/r/deepfakes/ code sample + contribs + Adapted from a model by VillainGuy (https://github.com/VillainGuy) """ + +from keras.initializers import RandomNormal +from keras.layers import add, Conv2D, Dense, Flatten, Input, Reshape +from keras.models import Model as KerasModel + +from lib.model.layers import PixelShuffler +from .original import logger, Model as OriginalModel + + +class Model(OriginalModel): + """ Villain Faceswap Model """ + def __init__(self, *args, **kwargs): + logger.debug("Initializing %s: (args: %s, kwargs: %s", + self.__class__.__name__, args, kwargs) + + kwargs["input_shape"] = (128, 128, 3) + kwargs["encoder_dim"] = 512 if self.config["lowmem"] else 1024 + self.kernel_initializer = RandomNormal(0, 0.02) + + super().__init__(*args, **kwargs) + logger.debug("Initialized %s", self.__class__.__name__) + + def encoder(self): + """ Encoder Network """ + kwargs = dict(kernel_initializer=self.kernel_initializer) + input_ = Input(shape=self.input_shape) + in_conv_filters = self.input_shape[0] + if self.input_shape[0] > 128: + in_conv_filters = 128 + (self.input_shape[0] - 128) // 4 + dense_shape = self.input_shape[0] // 16 + + var_x = self.blocks.conv(input_, in_conv_filters, res_block_follows=True, **kwargs) + tmp_x = var_x + res_cycles = 8 if self.config.get("lowmem", False) else 16 + for _ in range(res_cycles): + nn_x = self.blocks.res_block(var_x, 128, **kwargs) + var_x = nn_x + # consider adding scale before this layer to scale the residual chain + var_x = add([var_x, tmp_x]) + var_x = self.blocks.conv(var_x, 128, **kwargs) + var_x = PixelShuffler()(var_x) + var_x = self.blocks.conv(var_x, 128, **kwargs) + var_x = PixelShuffler()(var_x) + var_x = self.blocks.conv(var_x, 128, **kwargs) + var_x = self.blocks.conv_sep(var_x, 256, **kwargs) + var_x = self.blocks.conv(var_x, 512, **kwargs) + if not self.config.get("lowmem", False): + var_x = self.blocks.conv_sep(var_x, 1024, **kwargs) + + var_x = Dense(self.encoder_dim, **kwargs)(Flatten()(var_x)) + var_x = Dense(dense_shape * dense_shape * 1024, **kwargs)(var_x) + var_x = Reshape((dense_shape, dense_shape, 1024))(var_x) + var_x = self.blocks.upscale(var_x, 512, **kwargs) + return KerasModel(input_, var_x) + + def decoder(self): + """ Decoder Network """ + kwargs = dict(kernel_initializer=self.kernel_initializer) + decoder_shape = self.input_shape[0] // 8 + input_ = Input(shape=(decoder_shape, decoder_shape, 512)) + + var_x = input_ + var_x = self.blocks.upscale(var_x, 512, res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, 512, **kwargs) + var_x = self.blocks.upscale(var_x, 256, res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, 256, **kwargs) + var_x = self.blocks.upscale(var_x, self.input_shape[0], res_block_follows=True, **kwargs) + var_x = self.blocks.res_block(var_x, self.input_shape[0], **kwargs) + var_x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(var_x) + outputs = [var_x] + + if self.config.get("mask_type", None): + var_y = input_ + var_y = self.blocks.upscale(var_y, 512) + var_y = self.blocks.upscale(var_y, 256) + var_y = self.blocks.upscale(var_y, self.input_shape[0]) + var_y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(var_y) + outputs.append(var_y) + return KerasModel(input_, outputs=outputs) diff --git a/plugins/train/trainer/__init__.py b/plugins/train/trainer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/train/trainer/_base.py b/plugins/train/trainer/_base.py new file mode 100644 index 0000000000..c8ba688b8c --- /dev/null +++ b/plugins/train/trainer/_base.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 + + +""" Base Trainer Class for Faceswap + + Trainers should be inherited from this class. + + A training_opts dictionary can be set in the corresponding model. + Accepted values: + alignments: dict containing paths to alignments files for keys 'a' and 'b' + preview_scaling: How much to scale the preview out by + training_size: Size of the training images + coverage_ratio: Ratio of face to be cropped out for training + mask_type: Type of mask to use. See lib.model.masks for valid mask names. + Set to None for not used + no_logs: Disable tensorboard logging + warp_to_landmarks: Use random_warp_landmarks instead of random_warp + no_flip: Don't perform a random flip on the image +""" + +import logging +import os +import time + +import cv2 +import numpy as np + +from tensorflow import keras as tf_keras + +from lib.alignments import Alignments +from lib.faces_detect import DetectedFace +from lib.training_data import TrainingDataGenerator, stack_images +from lib.utils import get_folder, get_image_paths + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class TrainerBase(): + """ Base Trainer """ + + def __init__(self, model, images, batch_size): + logger.debug("Initializing %s: (model: '%s', batch_size: %s)", + self.__class__.__name__, model, batch_size) + self.batch_size = batch_size + self.model = model + self.model.state.add_session_batchsize(batch_size) + self.images = images + + self.process_training_opts() + + self.batchers = {side: Batcher(side, + images[side], + self.model, + self.use_mask, + batch_size) + for side in images.keys()} + + self.tensorboard = self.set_tensorboard() + self.samples = Samples(self.model, + self.use_mask, + self.model.training_opts["coverage_ratio"], + self.model.training_opts["preview_scaling"]) + self.timelapse = Timelapse(self.model, + self.use_mask, + self.model.training_opts["coverage_ratio"], + self.batchers) + logger.debug("Initialized %s", self.__class__.__name__) + + @property + def timestamp(self): + """ Standardised timestamp for loss reporting """ + return time.strftime("%H:%M:%S") + + @property + def landmarks_required(self): + """ Return True if Landmarks are required """ + opts = self.model.training_opts + retval = bool(opts.get("mask_type", None) or opts["warp_to_landmarks"]) + logger.debug(retval) + return retval + + @property + def use_mask(self): + """ Return True if a mask is requested """ + retval = bool(self.model.training_opts.get("mask_type", None)) + logger.debug(retval) + return retval + + def process_training_opts(self): + """ Override for processing model specific training options """ + logger.debug(self.model.training_opts) + if self.landmarks_required: + landmarks = Landmarks(self.model.training_opts).landmarks + self.model.training_opts["landmarks"] = landmarks + + def set_tensorboard(self): + """ Set up tensorboard callback """ + if self.model.training_opts["no_logs"]: + logger.verbose("TensorBoard logging disabled") + return None + + logger.debug("Enabling TensorBoard Logging") + tensorboard = dict() + for side in self.images.keys(): + logger.debug("Setting up TensorBoard Logging. Side: %s", side) + log_dir = os.path.join(str(self.model.model_dir), + "{}_logs".format(self.model.name), + side, + "session_{}".format(self.model.state.session_id)) + tbs = tf_keras.callbacks.TensorBoard(log_dir=log_dir, + histogram_freq=0, # Must be 0 or hangs + batch_size=self.batch_size, + write_graph=True, + write_grads=True) + tbs.set_model(self.model.predictors[side]) + tensorboard[side] = tbs + logger.info("Enabled TensorBoard Logging") + return tensorboard + + def print_loss(self, loss): + """ Override for specific model loss formatting """ + output = list() + for side in sorted(list(loss.keys())): + display = ", ".join(["{}_{}: {:.5f}".format(self.model.state.loss_names[side][idx], + side.capitalize(), + this_loss) + for idx, this_loss in enumerate(loss[side])]) + output.append(display) + print("[{}] [#{:05d}] {}, {}".format( + self.timestamp, self.model.iterations, output[0], output[1]), end='\r') + + def train_one_step(self, viewer, timelapse_kwargs): + """ Train a batch """ + logger.trace("Training one step: (iteration: %s)", self.model.iterations) + is_preview_iteration = False if viewer is None else True + loss = dict() + for side, batcher in self.batchers.items(): + loss[side] = batcher.train_one_batch(is_preview_iteration) + if not is_preview_iteration: + continue + self.samples.images[side] = batcher.compile_sample(self.batch_size) + if timelapse_kwargs: + self.timelapse.get_sample(side, timelapse_kwargs) + + self.model.state.increment_iterations() + + for side, side_loss in loss.items(): + self.store_history(side, side_loss) + self.log_tensorboard(side, side_loss) + self.print_loss(loss) + + if viewer is not None: + viewer(self.samples.show_sample(), + "Training - 'S': Save Now. 'ENTER': Save and Quit") + + if timelapse_kwargs is not None: + self.timelapse.output_timelapse() + + def store_history(self, side, loss): + """ Store the history of this step """ + logger.trace("Updating loss history: '%s'", side) + self.model.history[side].append(loss[0]) # Either only loss or total loss + logger.trace("Updated loss history: '%s'", side) + + def log_tensorboard(self, side, loss): + """ Log loss to TensorBoard log """ + if not self.tensorboard: + return + logger.trace("Updating TensorBoard log: '%s'", side) + logs = {log[0]: log[1] + for log in zip(self.model.state.loss_names[side], loss)} + self.tensorboard[side].on_batch_end(self.model.state.iterations, logs) + logger.trace("Updated TensorBoard log: '%s'", side) + + def clear_tensorboard(self): + """ Indicate training end to Tensorboard """ + if not self.tensorboard: + return + for side, tensorboard in self.tensorboard.items(): + logger.debug("Ending Tensorboard. Side: '%s'", side) + tensorboard.on_train_end(None) + + +class Batcher(): + """ Batch images from a single side """ + def __init__(self, side, images, model, use_mask, batch_size): + logger.debug("Initializing %s: side: '%s', num_images: %s, batch_size: %s)", + self.__class__.__name__, side, len(images), batch_size) + self.model = model + self.use_mask = use_mask + self.side = side + self.target = None + self.samples = None + self.mask = None + + self.feed = self.load_generator().minibatch_ab(images, batch_size, self.side) + self.timelapse_feed = None + + def load_generator(self): + """ Pass arguments to TrainingDataGenerator and return object """ + logger.debug("Loading generator: %s", self.side) + input_size = self.model.input_shape[0] + output_size = self.model.output_shape[0] + logger.debug("input_size: %s, output_size: %s", input_size, output_size) + generator = TrainingDataGenerator(input_size, output_size, self.model.training_opts) + return generator + + def train_one_batch(self, is_preview_iteration): + """ Train a batch """ + logger.trace("Training one step: (side: %s)", self.side) + batch = self.get_next(is_preview_iteration) + loss = self.model.predictors[self.side].train_on_batch(*batch) + loss = loss if isinstance(loss, list) else [loss] + return loss + + def get_next(self, is_preview_iteration): + """ Return the next batch from the generator + Items should come out as: (warped, target [, mask]) """ + batch = next(self.feed) + self.samples = batch[0] if is_preview_iteration else None + batch = batch[1:] # Remove full size samples from batch + if self.use_mask: + batch = self.compile_mask(batch) + self.target = batch[1] if is_preview_iteration else None + return batch + + def compile_mask(self, batch): + """ Compile the mask into training data """ + logger.trace("Compiling Mask: (side: '%s')", self.side) + mask = batch[-1] + retval = list() + for idx in range(len(batch) - 1): + image = batch[idx] + retval.append([image, mask]) + return retval + + def compile_sample(self, batch_size, samples=None, images=None): + """ Training samples to display in the viewer """ + num_images = self.model.training_opts.get("preview_images", 14) + num_images = min(batch_size, num_images) + logger.debug("Compiling samples: (side: '%s', samples: %s)", self.side, num_images) + images = images if images is not None else self.target + samples = [samples[0:num_images]] if samples is not None else [self.samples[0:num_images]] + if self.use_mask: + retval = [tgt[0:num_images] for tgt in images] + else: + retval = [images[0:num_images]] + retval = samples + retval + return retval + + def compile_timelapse_sample(self): + """ Timelapse samples """ + batch = next(self.timelapse_feed) + samples = batch[0] + batch = batch[1:] # Remove full size samples from batch + batchsize = len(samples) + if self.use_mask: + batch = self.compile_mask(batch) + images = batch[1] + sample = self.compile_sample(batchsize, samples=samples, images=images) + return sample + + def set_timelapse_feed(self, images, batchsize): + """ Set the timelapse dictionary """ + logger.debug("Setting timelapse feed: (side: '%s', input_images: '%s', batchsize: %s)", + self.side, images, batchsize) + self.timelapse_feed = self.load_generator().minibatch_ab(images[:batchsize], + batchsize, self.side, + do_shuffle=False, + is_timelapse=True) + logger.debug("Set timelapse feed") + + +class Samples(): + """ Display samples for preview and timelapse """ + def __init__(self, model, use_mask, coverage_ratio, scaling=1.0): + logger.debug("Initializing %s: model: '%s', use_mask: %s, coverage_ratio: %s)", + self.__class__.__name__, model, use_mask, coverage_ratio) + self.model = model + self.use_mask = use_mask + self.images = dict() + self.coverage_ratio = coverage_ratio + self.scaling = scaling + logger.debug("Initialized %s", self.__class__.__name__) + + def show_sample(self): + """ Display preview data """ + logger.debug("Showing sample") + feeds = dict() + figures = dict() + headers = dict() + for side, samples in self.images.items(): + faces = samples[1] + if self.model.input_shape[0] / faces.shape[1] != 1.0: + feeds[side] = self.resize_sample(side, faces, self.model.input_shape[0]) + feeds[side] = feeds[side].reshape((-1, ) + self.model.input_shape) + else: + feeds[side] = faces + if self.use_mask: + mask = samples[-1] + feeds[side] = [feeds[side], mask] + + preds = self.get_predictions(feeds["a"], feeds["b"]) + + for side, samples in self.images.items(): + other_side = "a" if side == "b" else "b" + predictions = [preds["{}_{}".format(side, side)], + preds["{}_{}".format(other_side, side)]] + display = self.to_full_frame(side, samples, predictions) + headers[side] = self.get_headers(side, other_side, display[0].shape[1]) + figures[side] = np.stack([display[0], display[1], display[2], ], axis=1) + if self.images[side][0].shape[0] % 2 == 1: + figures[side] = np.concatenate([figures[side], + np.expand_dims(figures[side][0], 0)]) + + width = 4 + side_cols = width // 2 + if side_cols != 1: + headers = self.duplicate_headers(headers, side_cols) + + header = np.concatenate([headers["a"], headers["b"]], axis=1) + figure = np.concatenate([figures["a"], figures["b"]], axis=0) + height = int(figure.shape[0] / width) + figure = figure.reshape((width, height) + figure.shape[1:]) + figure = stack_images(figure) + figure = np.vstack((header, figure)) + + logger.debug("Compiled sample") + return np.clip(figure * 255, 0, 255).astype('uint8') + + @staticmethod + def resize_sample(side, sample, target_size): + """ Resize samples where predictor expects different shape from processed image """ + scale = target_size / sample.shape[1] + if scale == 1.0: + return sample + logger.debug("Resizing sample: (side: '%s', sample.shape: %s, target_size: %s, scale: %s)", + side, sample.shape, target_size, scale) + interpn = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA # pylint: disable=no-member + retval = np.array([cv2.resize(img, # pylint: disable=no-member + (target_size, target_size), + interpn) + for img in sample]) + logger.debug("Resized sample: (side: '%s' shape: %s)", side, retval.shape) + return retval + + def get_predictions(self, feed_a, feed_b): + """ Return the sample predictions from the model """ + logger.debug("Getting Predictions") + preds = dict() + preds["a_a"] = self.model.predictors["a"].predict(feed_a) + preds["b_a"] = self.model.predictors["b"].predict(feed_a) + preds["a_b"] = self.model.predictors["a"].predict(feed_b) + preds["b_b"] = self.model.predictors["b"].predict(feed_b) + + # Get the returned image from predictors that emit multiple items + if not isinstance(preds["a_a"], np.ndarray): + for key, val in preds.items(): + preds[key] = val[0] + logger.debug("Returning predictions: %s", {key: val.shape for key, val in preds.items()}) + return preds + + def to_full_frame(self, side, samples, predictions): + """ Patch the images into the full frame """ + logger.debug("side: '%s', number of sample arrays: %s, prediction.shapes: %s)", + side, len(samples), [pred.shape for pred in predictions]) + full, faces = samples[:2] + images = [faces] + predictions + full_size = full.shape[1] + target_size = int(full_size * self.coverage_ratio) + if target_size != full_size: + frame = self.frame_overlay(full, target_size, (0, 0, 255)) + + if self.use_mask: + images = self.compile_masked(images, samples[-1]) + images = [self.resize_sample(side, image, target_size) for image in images] + if target_size != full_size: + images = [self.overlay_foreground(frame, image) for image in images] + if self.scaling != 1.0: + new_size = int(full_size * self.scaling) + images = [self.resize_sample(side, image, new_size) for image in images] + return images + + @staticmethod + def frame_overlay(images, target_size, color): + """ Add roi frame to a backfround image """ + logger.debug("full_size: %s, target_size: %s, color: %s", + images.shape[1], target_size, color) + new_images = list() + full_size = images.shape[1] + padding = (full_size - target_size) // 2 + length = target_size // 4 + t_l, b_r = (padding, full_size - padding) + for img in images: + cv2.rectangle(img, # pylint: disable=no-member + (t_l, t_l), + (t_l + length, t_l + length), + color, + 3) + cv2.rectangle(img, # pylint: disable=no-member + (b_r, t_l), + (b_r - length, t_l + length), + color, + 3) + cv2.rectangle(img, # pylint: disable=no-member + (b_r, b_r), + (b_r - length, + b_r - length), + color, + 3) + cv2.rectangle(img, # pylint: disable=no-member + (t_l, b_r), + (t_l + length, b_r - length), + color, + 3) + new_images.append(img) + retval = np.array(new_images) + logger.debug("Overlayed background. Shape: %s", retval.shape) + return retval + + @staticmethod + def compile_masked(faces, masks): + """ Add the mask to the faces for masked preview """ + retval = list() + masks3 = np.tile(1 - np.rint(masks), 3) + for mask in masks3: + mask[np.where((mask == [1., 1., 1.]).all(axis=2))] = [0., 0., 1.] + for previews in faces: + images = np.array([cv2.addWeighted(img, 1.0, # pylint: disable=no-member + masks3[idx], 0.3, + 0) + for idx, img in enumerate(previews)]) + retval.append(images) + logger.debug("masked shapes: %s", [faces.shape for faces in retval]) + return retval + + @staticmethod + def overlay_foreground(backgrounds, foregrounds): + """ Overlay the training images into the center of the background """ + offset = (backgrounds.shape[1] - foregrounds.shape[1]) // 2 + new_images = list() + for idx, img in enumerate(backgrounds): + img[offset:offset + foregrounds[idx].shape[0], + offset:offset + foregrounds[idx].shape[1]] = foregrounds[idx] + new_images.append(img) + retval = np.array(new_images) + logger.debug("Overlayed foreground. Shape: %s", retval.shape) + return retval + + def get_headers(self, side, other_side, width): + """ Set headers for images """ + logger.debug("side: '%s', other_side: '%s', width: %s", + side, other_side, width) + side = side.upper() + other_side = other_side.upper() + height = int(64 * self.scaling) + total_width = width * 3 + logger.debug("height: %s, total_width: %s", height, total_width) + font = cv2.FONT_HERSHEY_SIMPLEX # pylint: disable=no-member + texts = ["Target {}".format(side), + "{} > {}".format(side, side), + "{} > {}".format(side, other_side)] + text_sizes = [cv2.getTextSize(texts[idx], # pylint: disable=no-member + font, + self.scaling, + 1)[0] + for idx in range(len(texts))] + text_y = int((height + text_sizes[0][1]) / 2) + text_x = [int((width - text_sizes[idx][0]) / 2) + width * idx + for idx in range(len(texts))] + logger.debug("texts: %s, text_sizes: %s, text_x: %s, text_y: %s", + texts, text_sizes, text_x, text_y) + header_box = np.ones((height, total_width, 3), np.float32) + for idx, text in enumerate(texts): + cv2.putText(header_box, # pylint: disable=no-member + text, + (text_x[idx], text_y), + font, + self.scaling, + (0, 0, 0), + 1, + lineType=cv2.LINE_AA) # pylint: disable=no-member + logger.debug("header_box.shape: %s", header_box.shape) + return header_box + + @staticmethod + def duplicate_headers(headers, columns): + """ Duplicate headers for the number of columns displayed """ + for side, header in headers.items(): + duped = tuple([header for _ in range(columns)]) + headers[side] = np.concatenate(duped, axis=1) + logger.debug("side: %s header.shape: %s", side, header.shape) + return headers + + +class Timelapse(): + """ Create the timelapse """ + def __init__(self, model, use_mask, coverage_ratio, batchers): + logger.debug("Initializing %s: model: %s, use_mask: %s, coverage_ratio: %s, " + "batchers: '%s')", self.__class__.__name__, model, use_mask, + coverage_ratio, batchers) + self.samples = Samples(model, use_mask, coverage_ratio) + self.model = model + self.batchers = batchers + self.output_file = None + logger.debug("Initialized %s", self.__class__.__name__) + + def get_sample(self, side, timelapse_kwargs): + """ Perform timelapse """ + logger.debug("Getting timelapse samples: '%s'", side) + if not self.output_file: + self.setup(**timelapse_kwargs) + self.samples.images[side] = self.batchers[side].compile_timelapse_sample() + logger.debug("Got timelapse samples: '%s' - %s", side, len(self.samples.images[side])) + + def setup(self, input_a=None, input_b=None, output=None): + """ Set the timelapse output folder """ + logger.debug("Setting up timelapse") + if output is None: + output = str(get_folder(os.path.join(str(self.model.model_dir), + "{}_timelapse".format(self.model.name)))) + self.output_file = str(output) + logger.debug("Timelapse output set to '%s'", self.output_file) + + images = {"a": get_image_paths(input_a), "b": get_image_paths(input_b)} + batchsize = min(len(images["a"]), + len(images["b"]), + self.model.training_opts.get("preview_images", 14)) + for side, image_files in images.items(): + self.batchers[side].set_timelapse_feed(image_files, batchsize) + logger.debug("Set up timelapse") + + def output_timelapse(self): + """ Set the timelapse dictionary """ + logger.debug("Ouputting timelapse") + image = self.samples.show_sample() + filename = os.path.join(self.output_file, str(int(time.time())) + ".jpg") + + cv2.imwrite(filename, image) # pylint: disable=no-member + logger.debug("Created timelapse: '%s'", filename) + + +class Landmarks(): + """ Set Landmarks for training into the model's training options""" + def __init__(self, training_opts): + logger.debug("Initializing %s: (training_opts: '%s')", + self.__class__.__name__, training_opts) + self.size = training_opts.get("training_size", 256) + self.paths = training_opts["alignments"] + self.landmarks = self.get_alignments() + logger.debug("Initialized %s", self.__class__.__name__) + + def get_alignments(self): + """ Obtain the landmarks for each faceset """ + landmarks = dict() + for side, fullpath in self.paths.items(): + path, filename = os.path.split(fullpath) + filename, extension = os.path.splitext(filename) + serializer = extension[1:] + alignments = Alignments( + path, + filename=filename, + serializer=serializer) + landmarks[side] = self.transform_landmarks(alignments) + return landmarks + + def transform_landmarks(self, alignments): + """ For each face transform landmarks and return """ + landmarks = dict() + for _, faces, _, _ in alignments.yield_faces(): + for face in faces: + detected_face = DetectedFace() + detected_face.from_alignment(face) + detected_face.load_aligned(None, size=self.size, align_eyes=False) + landmarks[detected_face.hash] = detected_face.aligned_landmarks + return landmarks diff --git a/plugins/train/trainer/original.py b/plugins/train/trainer/original.py new file mode 100644 index 0000000000..14cb9afbaf --- /dev/null +++ b/plugins/train/trainer/original.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +""" Original Trainer """ + +from ._base import TrainerBase as Trainer diff --git a/scripts/convert.py b/scripts/convert.py index c0a7d53602..7b2301a712 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -15,7 +15,6 @@ from lib.multithreading import BackgroundGenerator, SpawnProcess from lib.queue_manager import queue_manager from lib.utils import get_folder, get_image_paths, hash_image_file - from plugins.plugin_loader import PluginLoader logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -43,7 +42,7 @@ def __init__(self, arguments): logger.debug("Initialized %s", self.__class__.__name__) def process(self): - """ Original & LowMem models go with Adjust or Masked converter + """ Original & LowMem models go with converter Note: GAN prediction outputs a mask + an image, while other predicts only an image. """ @@ -103,37 +102,19 @@ def load_extractor(self): def load_model(self): """ Load the model requested for conversion """ - model_name = self.args.trainer + logger.debug("Loading Model") model_dir = get_folder(self.args.model_dir) - num_gpus = self.args.gpus - - model = PluginLoader.get_model(model_name)(model_dir, num_gpus) - - if not model.load(self.args.swap_model): - logger.error("Model Not Found! A valid model " - "must be provided to continue!") - exit(1) - + model = PluginLoader.get_model(self.args.trainer)(model_dir, self.args.gpus, predict=True) + logger.debug("Loaded Model") return model def load_converter(self, model): """ Load the requested converter for conversion """ - args = self.args - conv = args.converter - + conv = self.args.converter converter = PluginLoader.get_converter(conv)( - model.converter(False), - trainer=args.trainer, - blur_size=args.blur_size, - seamless_clone=args.seamless_clone, - sharpen_image=args.sharpen_image, - mask_type=args.mask_type, - erosion_kernel_size=args.erosion_kernel_size, - match_histogram=args.match_histogram, - smooth_mask=args.smooth_mask, - avg_color_adjust=args.avg_color_adjust, - draw_transparent=args.draw_transparent) - + model.converter(self.args.swap_model), + model=model, + arguments=self.args) return converter def prepare_images(self): @@ -205,25 +186,13 @@ def convert(self, converter, item): if not skip: for face in faces: - image = self.convert_one_face(converter, image, face) + image = converter.patch_image(image, face) filename = str(self.output_dir / Path(filename).name) cv2.imwrite(filename, image) # pylint: disable=no-member except Exception as err: logger.error("Failed to convert image: '%s'. Reason: %s", filename, err) raise - def convert_one_face(self, converter, image, face): - """ Perform the conversion on the given frame for a single face """ - # TODO: This switch between 64 and 128 is a hack for now. - # We should have a separate cli option for size - size = 128 if (self.args.trainer.strip().lower() - in ('gan128', 'originalhighres')) else 64 - - image = converter.patch_image(image, - face, - size) - return image - class OptionalActions(): """ Process the optional actions for convert """ @@ -305,10 +274,8 @@ def check_skipframe(self, filename): class Legacy(): """ Update legacy alignments: - - - Add frame dimensions - Rotate landmarks and bounding boxes on legacy alignments - and remove the 'r' parameter + and remove the 'r' parameter - Add face hashes to alignments file """ def __init__(self, alignments, frames, faces_dir): @@ -319,15 +286,10 @@ def __init__(self, alignments, frames, faces_dir): def process(self, faces_dir): """ Run the rotate alignments process """ - no_dims = self.alignments.get_legacy_no_dims() rotated = self.alignments.get_legacy_rotation() hashes = self.alignments.get_legacy_no_hashes() - if not no_dims and not rotated and not hashes: + if not rotated and not hashes: return - if no_dims: - logger.info("Legacy landmarks found. Adding frame dimensions...") - self.add_dimensions(no_dims) - self.alignments.save() if rotated: logger.info("Legacy rotated frames found. Converting...") self.rotate_landmarks(rotated) @@ -337,22 +299,14 @@ def process(self, faces_dir): self.add_hashes(hashes, faces_dir) self.alignments.save() - def add_dimensions(self, no_dims): - """ Add width and height of original frame to alignments """ - for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"): - if no_dim not in self.frames.keys(): - continue - filename = self.frames[no_dim] - dims = cv2.imread(filename).shape[:2] # pylint: disable=no-member - self.alignments.add_dimensions(no_dim, dims) - def rotate_landmarks(self, rotated): """ Rotate the landmarks """ for rotate_item in tqdm(rotated, desc="Rotating Landmarks"): - if rotate_item not in self.frames.keys(): + frame = self.frames.get(rotate_item, None) + if frame is None: logger.debug("Skipping missing frame: '%s'", rotate_item) continue - self.alignments.rotate_existing_landmarks(rotate_item) + self.alignments.rotate_existing_landmarks(rotate_item, frame) def add_hashes(self, hashes, faces_dir): """ Add Face Hashes to the alignments file """ diff --git a/scripts/extract.py b/scripts/extract.py index e0a99f762d..cc1b081c1c 100644 --- a/scripts/extract.py +++ b/scripts/extract.py @@ -54,7 +54,7 @@ def process(self): self.verify_output) def threaded_io(self, task, io_args=None): - """ Load images in a background thread """ + """ Perform I/O task in a background thread """ logger.debug("Threading task: (Task: '%s')", task) io_args = tuple() if io_args is None else (io_args, ) if task == "load": @@ -211,7 +211,7 @@ def run_detection(self, to_process): self.threaded_io("reload", detected_faces) - def align_face(self, faces, align_eyes, size, filename, padding=48): + def align_face(self, faces, align_eyes, size, filename): """ Align the detected face and add the destination file path """ final_faces = list() image = faces["image"] @@ -221,11 +221,7 @@ def align_face(self, faces, align_eyes, size, filename, padding=48): detected_face = DetectedFace() detected_face.from_dlib_rect(face, image) detected_face.landmarksXY = landmarks[idx] - detected_face.frame_dims = image.shape[:2] - detected_face.load_aligned(image, - size=size, - padding=padding, - align_eyes=align_eyes) + detected_face.load_aligned(image, size=size, align_eyes=align_eyes) final_faces.append({"file_location": self.output_dir / Path(filename).stem, "face": detected_face}) faces["detected_faces"] = final_faces @@ -262,7 +258,7 @@ def __init__(self, arguments): logger.debug("Initialized %s", self.__class__.__name__) def set_parallel_processing(self): - """ Set whether to run detect and align together or seperately """ + """ Set whether to run detect and align together or separately """ detector_vram = self.detector.vram aligner_vram = self.aligner.vram gpu_stats = GPUStats() @@ -356,11 +352,6 @@ def launch_detector(self): kwargs = {"in_queue": queue_manager.get_queue("load"), "out_queue": out_queue} - if self.args.detector == "mtcnn": - mtcnn_kwargs = self.detector.validate_kwargs( - self.get_mtcnn_kwargs()) - kwargs["mtcnn_kwargs"] = mtcnn_kwargs - mp_func = PoolProcess if self.detector.parent_is_pool else SpawnProcess self.process_detect = mp_func(self.detector.run, **kwargs) @@ -384,14 +375,6 @@ def launch_detector(self): logger.debug("Launched Detector") - def get_mtcnn_kwargs(self): - """ Add the mtcnn arguments into a kwargs dictionary """ - mtcnn_threshold = [float(thr.strip()) - for thr in self.args.mtcnn_threshold] - return {"minsize": self.args.mtcnn_minsize, - "threshold": mtcnn_threshold, - "factor": self.args.mtcnn_scalefactor} - def detect_faces(self, extract_pass="detect"): """ Detect faces from in an image """ logger.debug("Running Detection. Pass: '%s'", extract_pass) diff --git a/scripts/fsmedia.py b/scripts/fsmedia.py index 53c665ae6a..8199afb252 100644 --- a/scripts/fsmedia.py +++ b/scripts/fsmedia.py @@ -180,7 +180,7 @@ def load(self): def load_disk_frames(self): """ Load frames from disk """ - logger.debug("Input is Seperate Frames. Loading images") + logger.debug("Input is separate Frames. Loading images") for filename in self.input_images: logger.trace("Loading image: '%s'", filename) try: @@ -314,9 +314,10 @@ def process(self, output_item): aligned_landmarks = face.aligned_landmarks resized_face = face.aligned_face size = face.aligned["size"] + padding = int(size * 0.1875) feature_mask = extractor.get_feature_mask( aligned_landmarks / size, - size, 48) + size, padding) feature_mask = cv2.blur( # pylint: disable=no-member feature_mask, (10, 10)) isolated_face = cv2.multiply( # pylint: disable=no-member diff --git a/scripts/gui.py b/scripts/gui.py index 6de7f69111..c4110f1479 100644 --- a/scripts/gui.py +++ b/scripts/gui.py @@ -1,101 +1,77 @@ #!/usr/bin python3 """ The optional GUI for faceswap """ -# NB: The GUI can't currently log as it is a wrapper for the python scripts, so don't -# implement logging unless you can handle the conflicts - +import logging import os import sys import tkinter as tk - from tkinter import messagebox, ttk -from lib.gui import (CliOptions, CurrentSession, CommandNotebook, Config, - ConsoleOut, DisplayNotebook, Images, ProcessWrapper, - StatusBar) +from lib.gui import (CliOptions, CommandNotebook, ConsoleOut, Session, DisplayNotebook, + get_config, get_images, initialize_images, initialize_config, MainMenuBar, + ProcessWrapper, StatusBar) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class FaceswapGui(tk.Tk): """ The Graphical User Interface """ def __init__(self, pathscript): - tk.Tk.__init__(self) - self.scaling_factor = self.get_scaling() + logger.debug("Initializing %s", self.__class__.__name__) + super().__init__() + + self.initialize_globals(pathscript) self.set_geometry() + self.wrapper = ProcessWrapper(pathscript) + get_images().delete_preview() + self.protocol("WM_DELETE_WINDOW", self.close_app) + logger.debug("Initialized %s", self.__class__.__name__) + + def initialize_globals(self, pathscript): + """ Initialize config and images global constants """ + cliopts = CliOptions() + scaling_factor = self.get_scaling() pathcache = os.path.join(pathscript, "lib", "gui", ".cache") - self.images = Images(pathcache) - self.cliopts = CliOptions() - self.session = CurrentSession() statusbar = StatusBar(self) - self.wrapper = ProcessWrapper(statusbar, - self.session, - pathscript, - self.cliopts) - - self.images.delete_preview() - self.protocol("WM_DELETE_WINDOW", self.close_app) + session = Session() + initialize_config(cliopts, scaling_factor, pathcache, statusbar, session) + initialize_images() def get_scaling(self): """ Get the display DPI """ dpi = self.winfo_fpixels("1i") - return dpi / 72.0 + scaling = dpi / 72.0 + logger.debug("dpi: %s, scaling: %s'", dpi, scaling) + return scaling def set_geometry(self): """ Set GUI geometry """ - self.tk.call("tk", "scaling", self.scaling_factor) - width = int(1200 * self.scaling_factor) - height = int(640 * self.scaling_factor) + scaling_factor = get_config().scaling_factor + self.tk.call("tk", "scaling", scaling_factor) + width = int(1200 * scaling_factor) + height = int(640 * scaling_factor) + logger.debug("Geometry: %sx%s", width, height) self.geometry("{}x{}+80+80".format(str(width), str(height))) def build_gui(self, debug_console): """ Build the GUI """ + logger.debug("Building GUI") self.title("Faceswap.py") - self.menu() + self.configure(menu=MainMenuBar(self)) topcontainer, bottomcontainer = self.add_containers() - CommandNotebook(topcontainer, - self.cliopts, - self.wrapper.tk_vars, - self.scaling_factor) - DisplayNotebook(topcontainer, - self.session, - self.wrapper.tk_vars, - self.scaling_factor) - ConsoleOut(bottomcontainer, debug_console, self.wrapper.tk_vars) - - def menu(self): - """ Menu bar for loading and saving configs """ - menubar = tk.Menu(self) - filemenu = tk.Menu(menubar, tearoff=0) - - config = Config(self.cliopts, self.wrapper.tk_vars) - - filemenu.add_command(label="Load full config...", - underline=0, - command=config.load) - filemenu.add_command(label="Save full config...", - underline=0, - command=config.save) - filemenu.add_separator() - filemenu.add_command(label="Reset all to default", - underline=0, - command=self.cliopts.reset) - filemenu.add_command(label="Clear all", - underline=0, - command=self.cliopts.clear) - filemenu.add_separator() - filemenu.add_command(label="Quit", - underline=0, - command=self.close_app) - - menubar.add_cascade(label="File", menu=filemenu, underline=0) - self.config(menu=menubar) + CommandNotebook(topcontainer) + DisplayNotebook(topcontainer) + ConsoleOut(bottomcontainer, debug_console) + logger.debug("Built GUI") def add_containers(self): """ Add the paned window containers that hold each main area of the gui """ + logger.debug("Adding containers") maincontainer = tk.PanedWindow(self, sashrelief=tk.RAISED, orient=tk.VERTICAL) @@ -109,21 +85,26 @@ def add_containers(self): bottomcontainer = ttk.Frame(maincontainer, height=150) maincontainer.add(bottomcontainer) + logger.debug("Added containers") return topcontainer, bottomcontainer def close_app(self): """ Close Python. This is here because the graph animation function continues to run even when tkinter has gone away """ + logger.debug("Close Requested") confirm = messagebox.askokcancel confirmtxt = "Processes are still running. Are you sure...?" - if (self.wrapper.tk_vars["runningtask"].get() + tk_vars = get_config().tk_vars + if (tk_vars["runningtask"].get() and not confirm("Close", confirmtxt)): + logger.debug("Close Cancelled") return - if self.wrapper.tk_vars["runningtask"].get(): + if tk_vars["runningtask"].get(): self.wrapper.task.terminate() - self.images.delete_preview() + get_images().delete_preview() self.quit() + logger.debug("Closed GUI") exit() diff --git a/scripts/train.py b/scripts/train.py index 06eabb0bcf..e2ff1507f4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,14 +4,18 @@ import logging import os import sys -import threading + +from threading import Lock +from time import sleep import cv2 import tensorflow as tf from keras.backend.tensorflow_backend import set_session -from lib.utils import (get_folder, get_image_paths, set_system_verbosity, - Timelapse) +from lib.keypress import KBHit +from lib.multithreading import MultiThread +from lib.queue_manager import queue_manager +from lib.utils import (get_folder, get_image_paths, set_system_verbosity) from plugins.plugin_loader import PluginLoader logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -20,38 +24,48 @@ class Train(): """ The training process. """ def __init__(self, arguments): + logger.debug("Initializing %s: (args: %s", self.__class__.__name__, arguments) self.args = arguments + self.timelapse = self.set_timelapse() self.images = self.get_images() self.stop = False self.save_now = False self.preview_buffer = dict() - self.lock = threading.Lock() + self.lock = Lock() - # this is so that you can enter case insensitive values for trainer - trainer_name = self.args.trainer - self.trainer_name = trainer_name - if trainer_name.lower() == "lowmem": - self.trainer_name = "LowMem" - self.timelapse = None + self.trainer_name = self.args.trainer + logger.debug("Initialized %s", self.__class__.__name__) - def process(self): - """ Call the training process object """ - logger.info("Training data directory: %s", self.args.model_dir) - set_system_verbosity(self.args.loglevel) - thread = self.start_thread() + def set_timelapse(self): + """ Set timelapse paths if requested """ + if (not self.args.timelapse_input_a and + not self.args.timelapse_input_b and + not self.args.timelapse_output): + return None + if not self.args.timelapse_input_a or not self.args.timelapse_input_b: + raise ValueError("To enable the timelapse, you have to supply " + "all the parameters (--timelapse-input-A and " + "--timelapse-input-B).") - if self.args.preview: - self.monitor_preview() - else: - self.monitor_console() + for folder in (self.args.timelapse_input_a, + self.args.timelapse_input_b, + self.args.timelapse_output): + if folder is not None and not os.path.isdir(folder): + raise ValueError("The Timelapse path '{}' does not exist".format(folder)) - self.end_thread(thread) + kwargs = {"input_a": self.args.timelapse_input_a, + "input_b": self.args.timelapse_input_b, + "output": self.args.timelapse_output} + logger.debug("Timelapse enabled: %s", kwargs) + return kwargs def get_images(self): """ Check the image dirs exist, contain images and return the image objects """ - images = [] - for image_dir in [self.args.input_A, self.args.input_B]: + logger.debug("Getting image paths") + images = dict() + for side in ("a", "b"): + image_dir = getattr(self.args, "input_{}".format(side)) if not os.path.isdir(image_dir): logger.error("Error: '%s' does not exist", image_dir) exit(1) @@ -60,30 +74,60 @@ def get_images(self): logger.error("Error: '%s' contains no images", image_dir) exit(1) - images.append(get_image_paths(image_dir)) - logger.info("Model A Directory: %s", self.args.input_A) - logger.info("Model B Directory: %s", self.args.input_B) + images[side] = get_image_paths(image_dir) + logger.info("Model A Directory: %s", self.args.input_a) + logger.info("Model B Directory: %s", self.args.input_b) + logger.debug("Got image paths: %s", [(key, str(len(val)) + " images") + for key, val in images.items()]) return images + def process(self): + """ Call the training process object """ + logger.debug("Starting Training Process") + logger.info("Training data directory: %s", self.args.model_dir) + set_system_verbosity(self.args.loglevel) + thread = self.start_thread() + # queue_manager.debug_monitor(1) + + if self.args.preview: + err = self.monitor_preview(thread) + else: + err = self.monitor_console(thread) + + self.end_thread(thread, err) + logger.debug("Completed Training Process") + def start_thread(self): """ Put the training process in a thread so we can keep control """ - thread = threading.Thread(target=self.process_thread) + logger.debug("Launching Trainer thread") + thread = MultiThread(target=self.training) thread.start() + logger.debug("Launched Trainer thread") return thread - def end_thread(self, thread): + def end_thread(self, thread, err): """ On termination output message and join thread back to main """ - logger.info("Exit requested! The trainer will complete its current cycle, " - "save the models and quit (it can take up a couple of seconds " - "depending on your training speed). If you want to kill it now, " - "press Ctrl + c") + logger.debug("Ending Training thread") + if err: + msg = "Error caught! Exiting..." + log = logger.critical + else: + msg = ("Exit requested! The trainer will complete its current cycle, " + "save the models and quit (it can take up a couple of seconds " + "depending on your training speed). If you want to kill it now, " + "press Ctrl + c") + log = logger.info + log(msg) self.stop = True thread.join() sys.stdout.flush() + logger.debug("Ended Training thread") - def process_thread(self): + def training(self): """ The training process to be run inside a thread """ try: + sleep(1) # Let preview instructions flush out to logger + logger.debug("Commencing Training") logger.info("Loading data, this may take a while...") if self.args.allow_growth: @@ -91,17 +135,12 @@ def process_thread(self): model = self.load_model() trainer = self.load_trainer(model) - - self.timelapse = Timelapse.create_timelapse( - self.args.timelapse_input_A, - self.args.timelapse_input_B, - self.args.timelapse_output, - trainer) - self.run_training_cycle(model, trainer) except KeyboardInterrupt: try: - model.save_weights() + logger.debug("Keyboard Interrupt Caught. Saving Weights and exiting") + model.save_models() + trainer.clear_tensorboard() except KeyboardInterrupt: logger.info("Saving model weights has been cancelled!") exit(0) @@ -110,105 +149,192 @@ def process_thread(self): def load_model(self): """ Load the model requested for training """ + logger.debug("Loading Model") model_dir = get_folder(self.args.model_dir) - model = PluginLoader.get_model(self.trainer_name)(model_dir, - self.args.gpus) - - model.load(swapped=False) + model = PluginLoader.get_model(self.trainer_name)( + model_dir, + self.args.gpus, + no_logs=self.args.no_logs, + warp_to_landmarks=self.args.warp_to_landmarks, + no_flip=self.args.no_flip, + training_image_size=self.image_size, + alignments_paths=self.alignments_paths, + preview_scale=self.args.preview_scale) + logger.debug("Loaded Model") return model + @property + def image_size(self): + """ Get the training set image size for storing in model data """ + image = cv2.imread(self.images["a"][0]) # pylint: disable=no-member + size = image.shape[0] + logger.debug("Training image size: %s", size) + return size + + @property + def alignments_paths(self): + """ Set the alignments path to input dirs if not provided """ + alignments_paths = dict() + for side in ("a", "b"): + alignments_path = getattr(self.args, "alignments_path_{}".format(side)) + if not alignments_path: + image_path = getattr(self.args, "input_{}".format(side)) + alignments_path = os.path.join(image_path, "alignments.json") + alignments_paths[side] = alignments_path + logger.debug("Alignments paths: %s", alignments_paths) + return alignments_paths + def load_trainer(self, model): """ Load the trainer requested for training """ - images_a, images_b = self.images - - trainer = PluginLoader.get_trainer(self.trainer_name) + logger.debug("Loading Trainer") + trainer = PluginLoader.get_trainer(model.trainer) trainer = trainer(model, - images_a, - images_b, - self.args.batch_size, - self.args.perceptual_loss) + self.images, + self.args.batch_size) + logger.debug("Loaded Trainer") return trainer def run_training_cycle(self, model, trainer): """ Perform the training cycle """ + logger.debug("Running Training Cycle") + if self.args.write_image or self.args.redirect_gui or self.args.preview: + display_func = self.show + else: + display_func = None + for iteration in range(0, self.args.iterations): + logger.trace("Training iteration: %s", iteration) save_iteration = iteration % self.args.save_interval == 0 - viewer = self.show if save_iteration or self.save_now else None - if save_iteration and self.timelapse is not None: - self.timelapse.work() - trainer.train_one_step(iteration, viewer) + viewer = display_func if save_iteration or self.save_now else None + timelapse = self.timelapse if save_iteration else None + trainer.train_one_step(viewer, timelapse) if self.stop: + logger.debug("Stop received. Terminating") break elif save_iteration: - model.save_weights() + logger.trace("Save Iteration: (iteration: %s", iteration) + model.save_models() elif self.save_now: - model.save_weights() + logger.trace("Save Requested: (iteration: %s", iteration) + model.save_models() self.save_now = False - model.save_weights() + logger.debug("Training cycle complete") + model.save_models() + trainer.clear_tensorboard() self.stop = True - def monitor_preview(self): + def monitor_preview(self, thread): """ Generate the preview window and wait for keyboard input """ - logger.info("Using live preview.\n" - "Press 'ENTER' on the preview window to save and quit.\n" - "Press 'S' on the preview window to save model weights " - "immediately") + logger.debug("Launching Preview Monitor") + logger.info("R|=====================================================================") + logger.info("R|- Using live preview -") + logger.info("R|- Press 'ENTER' on the preview window to save and quit -") + logger.info("R|- Press 'S' on the preview window to save model weights immediately -") + logger.info("R|=====================================================================") + err = False while True: try: with self.lock: for name, image in self.preview_buffer.items(): - cv2.imshow(name, image) + cv2.imshow(name, image) # pylint: disable=no-member - key = cv2.waitKey(1000) + key = cv2.waitKey(1000) # pylint: disable=no-member + if self.stop: + logger.debug("Stop received") + break + if thread.has_error: + logger.debug("Thread error detected") + err = True + break if key == ord("\n") or key == ord("\r"): + logger.debug("Exit requested") break if key == ord("s"): + logger.info("Save requested") self.save_now = True + except KeyboardInterrupt: + logger.debug("Keyboard Interrupt received") + break + logger.debug("Closed Preview Monitor") + return err + + def monitor_console(self, thread): + """ Monitor the console + NB: A custom function needs to be used for this because + input() blocks """ + logger.debug("Launching Console Monitor") + logger.info("R|===============================================") + logger.info("R|- Starting -") + logger.info("R|- Press 'ENTER' to save and quit -") + logger.info("R|- Press 'S' to save model weights immediately -") + logger.info("R|===============================================") + keypress = KBHit(is_gui=self.args.redirect_gui) + err = False + while True: + try: + if thread.has_error: + logger.debug("Thread error detected") + err = True + break if self.stop: + logger.debug("Stop received") break + if keypress.kbhit(): + key = keypress.getch() + if key in ("\n", "\r"): + logger.debug("Exit requested") + break + if key in ("s", "S"): + logger.info("Save requested") + self.save_now = True except KeyboardInterrupt: + logger.debug("Keyboard Interrupt received") break + keypress.set_normal_term() + logger.debug("Closed Console Monitor") + return err @staticmethod - def monitor_console(): - """ Monitor the console for any input followed by enter or ctrl+c """ - # TODO: how to catch a specific key instead of Enter? - # there isn't a good multiplatform solution: - # https://stackoverflow.com/questions/3523174 - # TODO: Find a way to interrupt input() if the target iterations are - # reached. At the moment, setting a target iteration and using the -p - # flag is the only guaranteed way to exit the training loop on - # hitting target iterations. - logger.info("Starting. Press 'ENTER' to stop training and save model") - try: - input() - except KeyboardInterrupt: - pass + def keypress_monitor(keypress_queue): + """ Monitor stdin for keypress """ + while True: + keypress_queue.put(sys.stdin.read(1)) @staticmethod def set_tf_allow_growth(): """ Allow TensorFlow to manage VRAM growth """ + # pylint: disable=no-member + logger.debug("Setting Tensorflow 'allow_growth' option") config = tf.ConfigProto() config.gpu_options.allow_growth = True config.gpu_options.visible_device_list = "0" set_session(tf.Session(config=config)) + logger.debug("Set Tensorflow 'allow_growth' option") def show(self, image, name=""): """ Generate the preview and write preview file output """ + logger.trace("Updating preview: (name: %s)", name) try: scriptpath = os.path.realpath(os.path.dirname(sys.argv[0])) if self.args.write_image: - img = "_sample_{}.jpg".format(name) + logger.trace("Saving preview to disk") + img = "training_preview.jpg" imgfile = os.path.join(scriptpath, img) - cv2.imwrite(imgfile, image) + cv2.imwrite(imgfile, image) # pylint: disable=no-member + logger.trace("Saved preview to: '%s'", img) if self.args.redirect_gui: - img = ".gui_preview_{}.jpg".format(name) + logger.trace("Generating preview for GUI") + img = ".gui_training_preview.jpg" imgfile = os.path.join(scriptpath, "lib", "gui", ".cache", "preview", img) - cv2.imwrite(imgfile, image) + cv2.imwrite(imgfile, image) # pylint: disable=no-member + logger.trace("Generated preview for GUI: '%s'", img) if self.args.preview: + logger.trace("Generating preview for display: '%s'", name) with self.lock: self.preview_buffer[name] = image + logger.trace("Generated preview for display: '%s'", name) except Exception as err: logging.error("could not preview sample") raise err + logger.trace("Updated preview: (name: %s)", name) diff --git a/tools/cli.py b/tools/cli.py index a67dcbf9c0..d243561424 100644 --- a/tools/cli.py +++ b/tools/cli.py @@ -2,7 +2,7 @@ """ Command Line Arguments for tools """ from lib.cli import FaceSwapArgs from lib.cli import (ContextFullPaths, DirFullPaths, - FileFullPaths, SaveFileFullPaths) + FileFullPaths, SaveFileFullPaths, Slider) from lib.utils import _image_extensions @@ -47,9 +47,9 @@ def get_argument_list(self): "\n\tfile." + output_opts + frames_dir + "\n'missing-frames': Identify frames in the alignments file that do no " "\n\tappear within the frames folder/video." + output_opts + frames_dir + - "\n'legacy': This updates legacy alignments to the latest format by adding" - "\n\tframe dimensions, rotating the landmarks and bounding boxes and adding" - "\n\tface_hashes" + frames_and_faces_dir + + "\n'legacy': This updates legacy alignments to the latest format by rotating" + "\n\tthe landmarks and bounding boxes and adding face_hashes." + + frames_and_faces_dir + "\n'leftover-faces': Identify faces in the faces folder that do not exist in" "\n\tthe alignments file." + output_opts + faces_dir + "\n'multi-faces': Identify where multiple faces exist within the alignments" @@ -123,6 +123,13 @@ def get_argument_list(self): "\n\tdirectory)." "\n'move': Move the discovered items to a sub-folder within the source" "\n\tdirectory."}) + argument_list.append({"opts": ("-sz", "--size"), + "type": int, + "action": Slider, + "min_max": (128, 512), + "default": 256, + "rounding": 64, + "help": "The output size of extracted faces. (extract only)"}) argument_list.append({"opts": ("-ae", "--align-eyes"), "action": "store_true", "dest": "align_eyes", @@ -409,6 +416,9 @@ def get_argument_list(): "Default: hist"}) argument_list.append({"opts": ('-t', '--ref_threshold'), + "action": Slider, + "min_max": (-1.0, 10.0), + "rounding": 2, "type": float, "dest": 'min_threshold', "default": -1.0, @@ -433,6 +443,9 @@ def get_argument_list(): "hist 0.3"}) argument_list.append({"opts": ('-b', '--bins'), + "action": Slider, + "min_max": (1, 100), + "rounding": 1, "type": int, "dest": 'num_bins', "default": 5, diff --git a/tools/lib_alignments/jobs.py b/tools/lib_alignments/jobs.py index 27ed0e2492..9710a4b1e7 100644 --- a/tools/lib_alignments/jobs.py +++ b/tools/lib_alignments/jobs.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -""" Tools for manipulating the alignments seralized file """ +""" Tools for manipulating the alignments serialized file """ import logging import os @@ -32,7 +32,7 @@ def __init__(self, alignments, arguments): logger.debug("Initialized %s", self.__class__.__name__) def get_source_dir(self, arguments): - """ Set the correct source dir """ + """ Set the correct source folder """ if hasattr(arguments, "faces_dir") and arguments.faces_dir: self.type = "faces" source_dir = arguments.faces_dir @@ -195,7 +195,7 @@ def move_frames(self, output_folder, items_output): os.rename(src, dst) def move_faces(self, output_folder, items_output): - """ Make additional subdirs for each face that appears + """ Make additional subfolders for each face that appears Enables easier manual sorting """ logger.info("Moving %s faces(s) to '%s'", len(items_output), output_folder) for frame, idx in items_output: @@ -239,7 +239,7 @@ def process(self): legacy.process() logger.info("[DRAW LANDMARKS]") # Tidy up cli output - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, + self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256, align_eyes=self.arguments.align_eyes) frames_drawn = 0 for frame in tqdm(self.frames.file_list_sorted, desc="Drawing landmarks"): @@ -281,7 +281,7 @@ def __init__(self, alignments, arguments): self.type = arguments.job.replace("extract-", "") self.faces_dir = arguments.faces_dir self.frames = Frames(arguments.frames_dir) - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, + self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=arguments.size, align_eyes=arguments.align_eyes) logger.debug("Initialized %s", self.__class__.__name__) @@ -352,8 +352,7 @@ def select_valid_faces(self, frame): valid_faces = faces else: sizes = self.extracted_faces.get_roi_size_for_frame(frame) - valid_faces = [faces[idx] - for idx, size in enumerate(sizes) + valid_faces = [faces[idx] for idx, size in enumerate(sizes) if size >= self.extracted_faces.size] logger.trace("frame: '%s', total_faces: %s, valid_faces: %s", frame, len(faces), len(valid_faces)) @@ -362,8 +361,6 @@ def select_valid_faces(self, frame): class Legacy(): """ Update legacy alignments: - - - Add frame dimensions - Rotate landmarks and bounding boxes on legacy alignments and remove the 'r' parameter - Add face hashes to alignments file @@ -383,16 +380,11 @@ def __init__(self, alignments, arguments, frames=None, faces=None, child_process def process(self): """ Run the rotate alignments process """ - no_dims = self.alignments.get_legacy_no_dims() rotated = self.alignments.get_legacy_rotation() hashes = self.alignments.get_legacy_no_hashes() - if (not self.frames or (not no_dims and not rotated)) and (not self.faces or not hashes): + if (not self.frames or not rotated) and (not self.faces or not hashes): return logger.info("[UPDATE LEGACY LANDMARKS]") # Tidy up cli output - if no_dims and self.frames: - logger.info("Legacy landmarks found. Adding frame dimensions...") - self.add_dimensions(no_dims) - self.alignments.save() if rotated and self.frames: logger.info("Legacy rotated frames found. Converting...") self.rotate_landmarks(rotated) @@ -402,20 +394,13 @@ def process(self): self.add_hashes(hashes) self.alignments.save() - def add_dimensions(self, no_dims): - """ Add width and height of original frame to alignments """ - for no_dim in tqdm(no_dims, desc="Adding Frame Dimensions"): - if no_dim not in self.frames.items.keys(): - continue - dims = self.frames.load_image(no_dim).shape[:2] - self.alignments.add_dimensions(no_dim, dims) - def rotate_landmarks(self, rotated): """ Rotate the landmarks """ for rotate_item in tqdm(rotated, desc="Rotating Landmarks"): - if rotate_item not in self.frames.items.keys(): + frame = self.frames.get(rotate_item, None) + if frame is None: continue - self.alignments.rotate_existing_landmarks(rotate_item) + self.alignments.rotate_existing_landmarks(rotate_item, frame) def add_hashes(self, hashes): """ Add Face Hashes to the alignments file """ @@ -838,19 +823,19 @@ def process(self): "alignments -j extract -a %s -fr -fc " "", self.arguments.alignments_file) - # define shape normalization utility functions + # Define shape normalization utility functions @staticmethod def normalize_shapes(shapes_im_coords): """ Normalize a 2D or 3D shape """ logger.debug("Normalize shapes") (num_pts, num_dims, _) = shapes_im_coords.shape - # calc mean coords and subtract from shapes + # Calculate mean coordinates and subtract from shapes mean_coords = shapes_im_coords.mean(axis=0) shapes_centered = np.zeros(shapes_im_coords.shape) shapes_centered = shapes_im_coords - np.tile(mean_coords, [num_pts, 1, 1]) - # calc scale factors and divide shapes + # Calculate scale factors and divide shapes scale_factors = np.sqrt((shapes_centered**2).sum(axis=1)).mean(axis=0) shapes_normalized = np.zeros(shapes_centered.shape) shapes_normalized = shapes_centered / np.tile(scale_factors, [num_pts, num_dims, 1]) @@ -889,12 +874,12 @@ def normalize(self): landmarks = np.array(val[0]["landmarksXY"]).reshape(68, 2, 1) start = end end = start + landmarks.shape[2] - # store in one big array + # Store in one big array landmarks_all[:, :, start:end] = landmarks - # make sure we keep track of the mapping to the original frame + # Make sure we keep track of the mapping to the original frame self.mappings[start] = key - # normalize shapes + # Normalize shapes normalized_shape = self.normalize_shapes(landmarks_all) self.normalized["landmarks"] = normalized_shape[0] self.normalized["scale_factors"] = normalized_shape[1] @@ -920,15 +905,15 @@ def spatially_filter(self): (project and reconstruct) """ logger.debug("Spatially Filter") landmarks_norm = self.normalized["landmarks"] - # convert to matrix form + # Convert to matrix form landmarks_norm_table = np.reshape(landmarks_norm, [68 * 2, landmarks_norm.shape[2]]).T - # project onto shapes model and reconstruct + # Project onto shapes model and reconstruct landmarks_norm_table_rec = self.shapes_model.inverse_transform( self.shapes_model.transform(landmarks_norm_table)) - # convert back to shapes (numKeypoint, num_dims, numFrames) + # Convert back to shapes (numKeypoint, num_dims, numFrames) landmarks_norm_rec = np.reshape(landmarks_norm_table_rec.T, [68, 2, landmarks_norm.shape[2]]) - # transform back to image coords + # Transform back to image coords retval = self.normalized_to_original(landmarks_norm_rec, self.normalized["scale_factors"], self.normalized["mean_coords"]) diff --git a/tools/lib_alignments/jobs_manual.py b/tools/lib_alignments/jobs_manual.py index 789d58917d..5428adab30 100644 --- a/tools/lib_alignments/jobs_manual.py +++ b/tools/lib_alignments/jobs_manual.py @@ -294,7 +294,7 @@ def get_skip_mode(self): def get_state_color(self): """ Return a color based on current state white - View Mode - yellow - Edit Mide + yellow - Edit Mode red - Unsaved alignments """ color = (255, 255, 255) if self.state["edit"]["updated"]: @@ -446,7 +446,7 @@ def process(self): legacy.process() logger.info("[MANUAL PROCESSING]") # Tidy up cli output - self.extracted_faces = ExtractedFaces(self.frames, self.alignments, + self.extracted_faces = ExtractedFaces(self.frames, self.alignments, size=256, align_eyes=self.align_eyes) self.interface = Interface(self.alignments, self.frames) self.help = Help(self.interface) @@ -510,8 +510,8 @@ def window_closed(self, is_windows, is_conda, key): MS Windows doesn't appear to read the window state property properly, so we check for a negative key press. - Conda (tested on Windows) doesn't sppear to read the window - state property or negative key press properly, so we arbitarily + Conda (tested on Windows) doesn't appear to read the window + state property or negative key press properly, so we arbitrarily use another property """ # pylint: disable=no-member logger.trace("Commencing closed window check") @@ -790,7 +790,7 @@ def init_extractor(self, loglevel): a_event = align_process.event align_process.start() - # Wait for Aligner to take init + # Wait for Aligner to initialize # The first ever load of the model for FAN has reportedly taken # up to 3-4 minutes, hence high timeout. a_event.wait(300) @@ -977,7 +977,8 @@ def update_landmarks(self): self.interface.state["edit"]["updated"] = True self.interface.state["edit"]["update_faces"] = True - def extracted_to_alignment(self, extract_data): + @staticmethod + def extracted_to_alignment(extract_data): """ Convert Extracted Tuple to Alignments data """ alignment = dict() d_rect, landmarks = extract_data @@ -985,6 +986,5 @@ def extracted_to_alignment(self, extract_data): alignment["w"] = d_rect.right() - d_rect.left() alignment["y"] = d_rect.top() alignment["h"] = d_rect.bottom() - d_rect.top() - alignment["frame_dims"] = self.media["image"].shape[:2] alignment["landmarksXY"] = landmarks return alignment diff --git a/tools/lib_alignments/media.py b/tools/lib_alignments/media.py index 5f683719e3..c8d3520bd1 100644 --- a/tools/lib_alignments/media.py +++ b/tools/lib_alignments/media.py @@ -53,7 +53,7 @@ def set_dfl(self, destination_format): self.set_destination_format(destination_format) def set_destination_format(self, destination_format): - """ Standardise the destination format to the correct extension """ + """ Standardize the destination format to the correct extension """ extensions = {".json": "json", ".p": "pickle", ".yml": "yaml", @@ -274,12 +274,11 @@ def sorted_items(self): class ExtractedFaces(): """ Holds the extracted faces and matrix for alignments """ - def __init__(self, frames, alignments, size=256, - padding=48, align_eyes=False): + def __init__(self, frames, alignments, size=256, align_eyes=False): logger.trace("Initializing %s: (size: %s, padding: %s, align_eyes: %s)", - self.__class__.__name__, size, padding, align_eyes) + self.__class__.__name__, size, align_eyes) self.size = size - self.padding = padding + self.padding = int(size * 0.1875) self.align_eyes = align_eyes self.alignments = alignments self.frames = frames @@ -309,10 +308,7 @@ def extract_one_face(self, alignment, image): self.current_frame, alignment) face = DetectedFace() face.from_alignment(alignment, image=image) - face.load_aligned(image, - size=self.size, - padding=self.padding, - align_eyes=self.align_eyes) + face.load_aligned(image, size=self.size, align_eyes=self.align_eyes) return face def get_faces_in_frame(self, frame, update=False): diff --git a/tools/sort.py b/tools/sort.py index 96af08de00..941a72253c 100644 --- a/tools/sort.py +++ b/tools/sort.py @@ -47,7 +47,7 @@ def process(self): # Assigning default threshold values based on grouping method if (self.args.final_process == "folders" - and self.args.min_threshold == -1.0): + and self.args.min_threshold < 0.0): method = self.args.group_method.lower() if method == 'face': self.args.min_threshold = 0.6 @@ -767,9 +767,9 @@ def estimate_blur(image_file): Normalize by pixel number to offset the effect of image size on pixel gradients & variance """ - image = cv2.imread(image_file,cv2.IMREAD_GRAYSCALE) + image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) blur_map = cv2.Laplacian(image, cv2.CV_32F) - score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1]) + score = np.var(blur_map) / np.sqrt(image.shape[0] * image.shape[1]) return score @staticmethod