Skip to content

Commit

Permalink
Bugfix - Weights freezing/loading for dfl-sae
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Nov 9, 2021
1 parent 860ccb9 commit 086933d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
9 changes: 5 additions & 4 deletions plugins/train/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ class _Weights():
Parameters
----------
plugin: :class:`Model`
The parent plugin class that owns the IO functions.
The parent plugin class that owns the weights functions.
"""
def __init__(self, plugin):
logger.debug("Initializing %s: (plugin: %s)", self.__class__.__name__, plugin)
@classmethod
def _check_weights_file(cls, weights_file):
    """ Validate that a weights file exists and is a Keras model (.h5) file.

    Parameters
    ----------
    weights_file: str
        Full path to the weights file selected for loading

    Returns
    -------
    str
        The validated weights file path

    Raises
    ------
    FaceswapError
        If the given path does not exist or does not carry a ``.h5`` extension
    """
    msg = ""
    if not os.path.exists(weights_file):
        msg = f"Load weights selected, but the path '{weights_file}' does not exist."
    elif os.path.splitext(weights_file)[-1].lower() != ".h5":
        msg = (f"Load weights selected, but the path '{weights_file}' is not a valid Keras "
               "model (.h5) file.")

    if msg:
        msg += " Please check and try again."
        logger.error(msg)
        raise FaceswapError(msg)

    logger.verbose("Using weights file: %s", weights_file)
    return weights_file
Expand Down
52 changes: 42 additions & 10 deletions plugins/train/model/dfl_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,31 @@

from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock

from ._base import ModelBase, KerasModel
from ._base import ModelBase, KerasModel, logger


class Model(ModelBase):
""" SAE Model from DFL """
def __init__(self, *args, **kwargs):

    # The freeze/load layer names must be patched into the config *before* the parent
    # initializes, because the encoder layer name varies with the selected architecture.
    # args[1] is the argparse Namespace from the train/convert process.
    self._patch_weights_management(args[1])

    super().__init__(*args, **kwargs)

    # Square input of (size, size, 3 colour channels)
    self.input_shape = (self.config["input_size"], self.config["input_size"], 3)
    self.architecture = self.config["architecture"].lower()  # "df" or "liae"
    self.use_mask = self.config.get("learn_mask", False)
    # The multiscale decoder outputs at 3 scales when enabled, otherwise a single output
    self.multiscale_count = 3 if self.config["multiscale_decoder"] else 1
    self.encoder_dim = self.config["encoder_dims"]
    self.decoder_dim = self.config["decoder_dims"]

@property
def name(self):
    """ str: The plugin name suffixed with the selected architecture. Overridden because
    DFL-SAE saves under a different name depending on the chosen architecture. """
    return f"{super().name}_{self.architecture}"

@property
def ae_dims(self):
""" Set the Autoencoder Dimensions or set to default """
Expand All @@ -31,6 +42,26 @@ def ae_dims(self):
retval = 256 if self.architecture == "liae" else 512
return retval

def _patch_weights_management(self, arguments):
    """ Patch the correct encoder name into the config dictionary for freezing and loading
    weights based on architecture.

    Because the model name varies based on the selected architecture, the config file
    needs to be loaded prior to initializing the parent class.

    Parameters
    ----------
    arguments: :class:`argparse.Namespace`
        The arguments that were passed to the train or convert process as generated from
        Faceswap's command line arguments
    """
    # The config file location must be set before :attr:`config` is first accessed
    self._configfile = getattr(arguments, "configfile", None)
    self.architecture = self.config["architecture"].lower()
    # Only the encoder belonging to the selected architecture should be frozen/loaded
    self.config["freeze_layers"] = [f"encoder_{self.architecture}"]
    self.config["load_layers"] = [f"encoder_{self.architecture}"]
    logger.debug("Patched encoder layers to config: %s", self.config)

def build_model(self, inputs):
""" Build the DFL-SAE Model """
encoder = getattr(self, "encoder_{}".format(self.architecture))()
Expand All @@ -53,7 +84,7 @@ def build_model(self, inputs):
self.decoder("b", enc_output_shape)(encoder_b)]
autoencoder = KerasModel(inputs,
outputs,
name="{}_{}".format(self.name, self.architecture))
name=self.name)
return autoencoder

def encoder_df(self):
Expand Down Expand Up @@ -133,11 +164,12 @@ def decoder(self, side, input_shape):

def _legacy_mapping(self):
    """ Obtain the mapping of legacy separate model file names to combined model layer names.

    Returns
    -------
    dict
        Legacy ``.h5`` file name mapped to the corresponding layer name within the single
        combined model, for the selected architecture

    Notes
    -----
    The legacy file names use the bare plugin name ("dfl_sae") rather than :attr:`name`,
    which has the architecture suffixed.
    """
    name = "dfl_sae"
    mappings = dict(df={f"{name}_encoder.h5": "encoder_df",
                        f"{name}_decoder_A.h5": "decoder_a",
                        f"{name}_decoder_B.h5": "decoder_b"},
                    liae={f"{name}_encoder.h5": "encoder_liae",
                          f"{name}_intermediate_B.h5": "intermediate_both",
                          f"{name}_intermediate.h5": "intermediate_b",
                          f"{name}_decoder.h5": "decoder_both"})
    return mappings[self.config["architecture"]]

0 comments on commit 086933d

Please sign in to comment.