Revert "Bugfix -Weights freezing/loading for dfl-sae"
This reverts commit 086933d.
torzdf committed Nov 10, 2021
1 parent 086933d commit 8b7b125
Showing 2 changed files with 14 additions and 47 deletions.
9 changes: 4 additions & 5 deletions plugins/train/model/_base.py
@@ -893,7 +893,7 @@ class _Weights():
     Parameters
     ----------
     plugin: :class:`Model`
-        The parent plugin class that owns the weights functions.
+        The parent plugin class that owns the IO functions.
     """
     def __init__(self, plugin):
         logger.debug("Initializing %s: (plugin: %s)", self.__class__.__name__, plugin)
@@ -928,14 +928,13 @@ def _check_weights_file(cls, weights_file):
 
         msg = ""
         if not os.path.exists(weights_file):
-            msg = f"Load weights selected, but the path '{weights_file}' does not exist."
+            msg = "Load weights selected, but the path '%s' does not exist."
         elif not os.path.splitext(weights_file)[-1].lower() == ".h5":
-            msg = (f"Load weights selected, but the path '{weights_file}' is not a valid Keras "
-                   f"model (.h5) file.")
+            msg = "Load weights selected, but the path '%s' is not a valid Keras model (.h5) file."
 
         if msg:
             msg += " Please check and try again."
-            raise FaceswapError(msg)
+            logger.error(msg)
 
         logger.verbose("Using weights file: %s", weights_file)
         return weights_file
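For context, the two behaviors this file's hunks swap can be condensed as follows. This is a minimal standalone sketch, not faceswap's actual module layout: FaceswapError is assumed here to be a plain Exception subclass, and the surrounding _Weights class is simplified away. Note that the restored message strings keep a '%s' placeholder that is never interpolated, so the path is logged as a literal '%s'.

    import logging
    import os

    logger = logging.getLogger("sketch")


    class FaceswapError(Exception):
        """ Stand-in for faceswap's error class (assumed: plain Exception subclass). """


    def check_weights_file_reverted_out(weights_file):
        """ The behavior this commit removes: an invalid path raises and halts. """
        if not os.path.exists(weights_file):
            msg = f"Load weights selected, but the path '{weights_file}' does not exist."
            msg += " Please check and try again."
            raise FaceswapError(msg)
        return weights_file


    def check_weights_file_restored(weights_file):
        """ The behavior this commit restores: an invalid path is logged and the
        process carries on. As written, the '%s' placeholder is never filled, so
        the offending path does not appear in the log message. """
        if not os.path.exists(weights_file):
            msg = "Load weights selected, but the path '%s' does not exist."
            msg += " Please check and try again."
            logger.error(msg)
        return weights_file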
52 changes: 10 additions & 42 deletions plugins/train/model/dfl_sae.py
@@ -9,31 +9,20 @@
 
 from lib.model.nn_blocks import Conv2DOutput, Conv2DBlock, ResidualBlock, UpscaleBlock
 
-from ._base import ModelBase, KerasModel, logger
+from ._base import ModelBase, KerasModel
 
 
 class Model(ModelBase):
     """ SAE Model from DFL """
     def __init__(self, *args, **kwargs):
-
-        self._patch_weights_management(args[1])
-
         super().__init__(*args, **kwargs)
 
         self.input_shape = (self.config["input_size"], self.config["input_size"], 3)
         self.architecture = self.config["architecture"].lower()
         self.use_mask = self.config.get("learn_mask", False)
         self.multiscale_count = 3 if self.config["multiscale_decoder"] else 1
         self.encoder_dim = self.config["encoder_dims"]
         self.decoder_dim = self.config["decoder_dims"]
 
-    @property
-    def name(self):
-        """ str: The name of this model based on the plugin name. Overridden as DFL-SAE is
-        named differently depending on the architecture selected. """
-        basename = super().name
-        name = f"{basename}_{self.architecture}"
-        return name
-
     @property
     def ae_dims(self):
         """ Set the Autoencoder Dimensions or set to default """
@@ -42,26 +31,6 @@ def ae_dims(self):
         retval = 256 if self.architecture == "liae" else 512
         return retval
 
-    def _patch_weights_management(self, arguments):
-        """ Patch in the correct encoder name into the config dictionary for freezing and loading
-        weights based on architecture.
-
-        Because of variable model name based on architecture, configfile needs to be loaded
-        prior to initializing parent
-
-        Parameters
-        ----------
-        arguments: :class:`argparse.Namespace`
-            The arguments that were passed to the train or convert process as generated from
-            Faceswap's command line arguments
-        """
-        self._configfile = arguments.configfile if hasattr(arguments, "configfile") else None
-        self.architecture = self.config["architecture"].lower()
-        self.config["freeze_layers"] = [f"encoder_{self.architecture}"]
-        self.config["load_layers"] = [f"encoder_{self.architecture}"]
-        logger.debug("Patched encoder layers to config: %s", self.config)
-
     def build_model(self, inputs):
         """ Build the DFL-SAE Model """
         encoder = getattr(self, "encoder_{}".format(self.architecture))()
@@ -84,7 +53,7 @@ def build_model(self, inputs):
                        self.decoder("b", enc_output_shape)(encoder_b)]
         autoencoder = KerasModel(inputs,
                                  outputs,
-                                 name=self.name)
+                                 name="{}_{}".format(self.name, self.architecture))
         return autoencoder
 
     def encoder_df(self):
@@ -164,12 +133,11 @@ def decoder(self, side, input_shape):
 
     def _legacy_mapping(self):
         """ The mapping of legacy separate model names to single model names """
-        name = "dfl_sae"
-        mappings = dict(df={"{}_encoder.h5".format(name): "encoder_df",
-                            "{}_decoder_A.h5".format(name): "decoder_a",
-                            "{}_decoder_B.h5".format(name): "decoder_b"},
-                        liae={"{}_encoder.h5".format(name): "encoder_liae",
-                              "{}_intermediate_B.h5".format(name): "intermediate_both",
-                              "{}_intermediate.h5".format(name): "intermediate_b",
-                              "{}_decoder.h5".format(name): "decoder_both"})
+        mappings = dict(df={"{}_encoder.h5".format(self.name): "encoder_df",
+                            "{}_decoder_A.h5".format(self.name): "decoder_a",
+                            "{}_decoder_B.h5".format(self.name): "decoder_b"},
+                        liae={"{}_encoder.h5".format(self.name): "encoder_liae",
+                              "{}_intermediate_B.h5".format(self.name): "intermediate_both",
+                              "{}_intermediate.h5".format(self.name): "intermediate_b",
+                              "{}_decoder.h5".format(self.name): "decoder_both"})
         return mappings[self.config["architecture"]]
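
Taken together, the dfl_sae.py hunks move the architecture suffix around: the reverted bugfix baked it into the plugin's name property and pre-seeded "freeze_layers"/"load_layers" in the config before parent initialization, whereas the restored code keeps the plugin name bare and appends the suffix only when the Keras model is constructed. Below is a condensed sketch of the two schemes; the class names and the plain dict standing in for faceswap's config are illustrative assumptions, not the project's real API.

    # Hypothetical condensation of the naming/patching logic shown in the diff above.


    class BugfixNaming:
        """ Reverted-out scheme: the suffix lives on the plugin name itself, and
        the architecture-specific encoder is patched into the config for weight
        freezing/loading before the parent class initializes. """

        def __init__(self, config):
            self.config = config
            self.architecture = config["architecture"].lower()
            config["freeze_layers"] = [f"encoder_{self.architecture}"]
            config["load_layers"] = [f"encoder_{self.architecture}"]

        @property
        def name(self):
            return f"dfl_sae_{self.architecture}"  # suffix applied once, here


    class RestoredNaming:
        """ Restored scheme: the plugin name stays bare; the suffix is appended
        only when the autoencoder's Keras model is built. """

        name = "dfl_sae"

        def __init__(self, config):
            self.architecture = config["architecture"].lower()

        def keras_model_name(self):
            return "{}_{}".format(self.name, self.architecture)


    if __name__ == "__main__":
        cfg = {"architecture": "LIAE"}
        print(BugfixNaming(dict(cfg)).name)            # dfl_sae_liae
        print(RestoredNaming(cfg).keras_model_name())  # dfl_sae_liae

Either way the resulting Keras model name is the same; what changes is which layer of the plugin owns the suffix, which in turn affects what self.name returns elsewhere (for example in the legacy-mapping keys above).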
