Update code to support Tensorflow versions up to 2.8 (deepfakes#1213)
* Update maximum tf version in setup + requirements

* - bump max version of tf version in launcher
- standardise tf version check

* update keras get_custom_objects  for tf>2.6

* bugfix: force black text in GUI file dialogs (linux)

* dssim loss - Move to stock tf.ssim function (a hedged sketch of such a loss appears after the changed-files summary below)

* Update optimizer imports for compatibility

* fix logging for tf2.8

* Fix GUI graphing for TF2.8

* update tests

* bump requirements.txt versions

* Remove limit on nvidia-ml-py

* Graphing bugfixes
  - Prevent live graph from displaying if data not yet available

* bugfix: Live graph. Collect loss labels correctly

* fix: live graph - swallow inconsistent loss errors

* Bugfix: Prevent live graph from clearing during training

* Fix graphing for AMD
torzdf authored May 2, 2022
1 parent cda49b3 commit c1512fd
Showing 23 changed files with 553 additions and 506 deletions.
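The "dssim loss" bullet above refers to replacing a hand-rolled DSSIM implementation with Tensorflow's stock SSIM op. The loss file itself is not shown on this page, so the following is only a minimal sketch of what such a loss can look like, assuming tf.image.ssim and inputs scaled to [0, 1]; the class name and defaults are illustrative rather than taken from the repository.

import tensorflow as tf

class DSSIMLoss(tf.keras.losses.Loss):  # illustrative name, not the repository's class
    """ Structural dissimilarity built on the stock tf.image.ssim function. """

    def __init__(self, max_value=1.0, name="dssim"):
        super().__init__(name=name)
        self._max_value = max_value  # assumes inputs are scaled to [0, max_value]

    def call(self, y_true, y_pred):
        # tf.image.ssim returns per-image SSIM in [-1, 1]; DSSIM maps it to [0, 1]
        ssim = tf.image.ssim(y_true, y_pred, max_val=self._max_value)
        return (1.0 - ssim) / 2.0

A model could then be compiled with, for example, model.compile(optimizer="adam", loss=DSSIMLoss()), purely to illustrate the loss signature.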
19 changes: 8 additions & 11 deletions _requirements_base.txt
@@ -1,19 +1,16 @@
- tqdm>=4.62
+ tqdm>=4.64
psutil>=5.8.0
- numpy>=1.18.0,<1.20.0
- opencv-python>=4.5.3.0
- pillow>=8.3.1
- scikit-learn>=0.24.2
- fastcluster>=1.1.26
+ numpy>=1.18.0
+ opencv-python>=4.5.5.0
+ pillow>=9.0.1
+ scikit-learn>=1.0.2
+ fastcluster>=1.2.4
# matplotlib 3.3.1 breaks custom toolbar in graph popup
matplotlib>=3.2.0,<3.3.0
imageio>=2.9.0
- imageio-ffmpeg>=0.4.5
+ imageio-ffmpeg>=0.4.7
ffmpy==0.2.3
# Exclude badly numbered Python2 version of nvidia-ml-py
- # nvidia-ml-py>=11.450,<300
- # v11.515.0 changes dtype of output items. Pinned for now
- # TODO update code to use latest version
- nvidia-ml-py>=11.450,<11.515
+ nvidia-ml-py>=11.510,<300
pywin32>=228 ; sys_platform == "win32"
pynvx==1.0.0 ; sys_platform == "darwin"
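The removed comment above explains why nvidia-ml-py was pinned: v11.515.0 changed the dtype of returned items. With the ceiling removed, the calling code presumably has to tolerate either return type. Below is a minimal sketch of that kind of defensive handling, assuming the dtype change affects string-returning calls such as nvmlDeviceGetName; it is not taken from this commit.

import pynvml  # provided by the nvidia-ml-py package

def get_device_names():
    """ Return GPU names as str, tolerating bytes or str across nvidia-ml-py versions. """
    pynvml.nvmlInit()
    try:
        names = []
        for idx in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
            name = pynvml.nvmlDeviceGetName(handle)
            # Some releases return bytes, others str (the dtype change noted above)
            names.append(name.decode("utf-8") if isinstance(name, bytes) else name)
        return names
    finally:
        pynvml.nvmlShutdown()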
22 changes: 11 additions & 11 deletions lib/cli/launcher.py
@@ -9,8 +9,8 @@

from lib.gpu_stats import set_exclude_devices, GPUStats
from lib.logger import crash_log, log_setup
- from lib.utils import (FaceswapError, get_backend, KerasFinder, safe_shutdown, set_backend,
-                        set_system_verbosity)
+ from lib.utils import (FaceswapError, get_backend, get_tf_version, KerasFinder, safe_shutdown,
+                        set_backend, set_system_verbosity)

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -41,7 +41,7 @@ def _import_script(self):
self._test_for_tf_version()
self._test_for_gui()
cmd = os.path.basename(sys.argv[0])
src = "tools.{}".format(self._command.lower()) if cmd == "tools.py" else "scripts"
src = f"tools.{self._command.lower()}" if cmd == "tools.py" else "scripts"
mod = ".".join((src, self._command.lower()))
module = import_module(mod)
script = getattr(module, self._command.title())
@@ -53,15 +53,15 @@ def _test_for_tf_version(self):
Raises
------
FaceswapError
- If Tensorflow is not found, or is not between versions 2.2 and 2.6
+ If Tensorflow is not found, or is not between versions 2.2 and 2.8
"""
min_ver = 2.2
- max_ver = 2.6
+ max_ver = 2.8
try:
# Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library
os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4"
os.environ["KMP_AFFINITY"] = "disabled"
- import tensorflow as tf  # pylint:disable=import-outside-toplevel
+ import tensorflow as tf  # noqa pylint:disable=import-outside-toplevel,unused-import
except ImportError as err:
if "DLL load failed while importing" in str(err):
msg = (
@@ -77,14 +77,14 @@ def _test_for_tf_version(self):
f"error: {str(err)}")
self._handle_import_error(msg)

- tf_ver = float(".".join(tf.__version__.split(".")[:2]))  # pylint:disable=no-member
+ tf_ver = get_tf_version()
if tf_ver < min_ver:
msg = ("The minimum supported Tensorflow is version {} but you have version {} "
"installed. Please upgrade Tensorflow.".format(min_ver, tf_ver))
msg = (f"The minimum supported Tensorflow is version {min_ver} but you have version "
f"{tf_ver} installed. Please upgrade Tensorflow.")
self._handle_import_error(msg)
if tf_ver > max_ver:
msg = ("The maximum supported Tensorflow is version {} but you have version {} "
"installed. Please downgrade Tensorflow.".format(max_ver, tf_ver))
msg = (f"The maximum supported Tensorflow is version {max_ver} but you have version "
f"{tf_ver} installed. Please downgrade Tensorflow.")
self._handle_import_error(msg)
logger.debug("Installed Tensorflow Version: %s", tf_ver)

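The launcher now delegates version parsing to get_tf_version() imported from lib.utils. The helper's own diff is not shown on this page; the sketch below simply mirrors the inline expression on the removed line, so it is an assumption about the helper rather than its actual definition.

import tensorflow as tf

def get_tf_version():
    """ Return the installed Tensorflow major.minor version as a float, e.g. 2.8. """
    # Mirrors the expression previously inlined in launcher.py's _test_for_tf_version
    return float(".".join(tf.__version__.split(".")[:2]))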
99 changes: 78 additions & 21 deletions lib/gui/analysis/event_reader.py
@@ -7,10 +7,12 @@

import numpy as np
import tensorflow as tf
- from tensorflow.core.util import event_pb2
- from tensorflow.python.framework import errors_impl as tf_errors
+ from tensorflow.core.util import event_pb2  # pylint:disable=no-name-in-module
+ from tensorflow.python.framework import (  # pylint:disable=no-name-in-module
+     errors_impl as tf_errors)

from lib.serializer import get_serializer
+ from lib.utils import get_backend

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -43,7 +45,7 @@ def _get_log_filenames(self):
The full path of each log file for each training session id that has been run
"""
logger.debug("Loading log filenames. base_dir: '%s'", self._logs_folder)
- retval = dict()
+ retval = {}
for dirpath, _, filenames in os.walk(self._logs_folder):
if not any(filename.startswith("events.out.tfevents") for filename in filenames):
continue
@@ -133,7 +135,7 @@ class _Cache():
def __init__(self, session_ids):
logger.debug("Initializing: %s: (session_ids: %s)", self.__class__.__name__, session_ids)
self._data = {idx: None for idx in session_ids}
- self._carry_over = dict()
+ self._carry_over = {}
self._loss_labels = []
logger.debug("Initialized: %s", self.__class__.__name__)

@@ -158,18 +160,19 @@ def cache_data(self, session_id, data, labels, is_live=False):
"""
logger.debug("Caching event data: (session_id: %s, labels: %s, data points: %s, "
"is_live: %s)", session_id, labels, len(data), is_live)
- if not data:
-     logger.debug("No data to cache")
-     return

if labels:
logger.debug("Setting loss labels: %s", labels)
self._loss_labels = labels

+ if not data:
+     logger.debug("No data to cache")
+     return

timestamps, loss = self._to_numpy(data, is_live)

if not is_live or (is_live and not self._data.get(session_id, None)):
- self._data[session_id] = dict(labels=labels,
+ self._data[session_id] = dict(labels=self._loss_labels,
loss=zlib.compress(loss),
loss_shape=loss.shape,
timestamps=zlib.compress(timestamps),
@@ -207,10 +210,30 @@ def _to_numpy(self, data, is_live):
for idx in sorted(data)])
times, loss = self._process_data(data, times, loss, is_live)

- times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
+ if is_live and not all(len(val) == len(self._loss_labels) for val in loss):
+     # TODO Many attempts have been made to fix this for live graph logging, and the issue
+     # of non-consistent loss record sizes keeps coming up. In the meantime we shall swallow
+     # any loss values that are of incorrect length so graph remains functional. This will,
+     # most likely, lead to a mismatch on iteration count so a proper fix should be
+     # implemented.
+
+     # Timestamps and loss appears to remain consistent with each other, but sometimes loss
+     # appears non-consistent. eg (lengths):
+     # [2, 2, 2, 2, 2, 2, 2, 0] - last loss collection has zero length
+     # [1, 2, 2, 2, 2, 2, 2, 2] - 1st loss collection has 1 length
+     # [2, 2, 2, 3, 2, 2, 2] - 4th loss collection has 3 length
+
+     logger.debug("Inconsistent loss found in collection: %s", loss)
+     for idx in reversed(range(len(loss))):
+         if len(loss[idx]) != len(self._loss_labels):
+             logger.debug("Removing loss/timestamps at position %s", idx)
+             del loss[idx]
+             del times[idx]
+
+ times, loss = (np.array(times, dtype="float64"), np.array(loss, dtype="float32"))
logger.debug("Converted to numpy: (data points: %s, timestamps shape: %s, loss shape: %s)",
len(data), times.shape, loss.shape)

return times, loss

def _collect_carry_over(self, data):
Expand Down Expand Up @@ -334,7 +357,7 @@ def get_data(self, session_id, metric):

dtype = "float32" if metric == "loss" else "float64"

- retval = dict()
+ retval = {}
for idx, data in raw.items():
val = {metric: np.frombuffer(zlib.decompress(data[metric]),
dtype=dtype).reshape(data[f"{metric}_shape"])}
@@ -461,7 +484,7 @@ def get_loss(self, session_id=None):
and list of loss values for each step
"""
logger.debug("Getting loss: (session_id: %s)", session_id)
- retval = dict()
+ retval = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
data = self._cache.get_data(idx, "loss")
Expand Down Expand Up @@ -493,7 +516,7 @@ def get_timestamps(self, session_id=None):

logger.debug("Getting timestamps: (session_id: %s, is_training: %s)",
session_id, self._is_training)
- retval = dict()
+ retval = {}
for idx in [session_id] if session_id else self.session_ids:
self._check_cache(idx)
data = self._cache.get_data(idx, "timestamps")
@@ -565,16 +588,19 @@ def cache_events(self, session_id):
session_id: int
The session id that the data is being cached for
"""
- data = dict()
+ data = {}
try:
for record in self._iterator:
event = event_pb2.Event.FromString(record) # pylint:disable=no-member
if not event.summary.value:
continue
if event.summary.value[0].tag == "keras":
self._parse_outputs(event)
if get_backend() == "amd":
# No model is logged for AMD so need to get loss labels from state file
self._add_amd_loss_labels(session_id)
if event.summary.value[0].tag.startswith("batch_"):
- data[event.step] = self._process_event(event, data.get(event.step, dict()))
+ data[event.step] = self._process_event(event, data.get(event.step, {}))

except tf_errors.DataLossError as err:
logger.warning("The logs for Session %s are corrupted and cannot be displayed. "
@@ -605,10 +631,6 @@ def _parse_outputs(self, event):
config = serializer.unmarshal(struct)["config"]
model_outputs = self._get_outputs(config)

- # loss length of unique should be 3:
- # - decoder_both, 1, 2
- # - docoder_a, decoder_b, 1
- split_output = len(np.unique(model_outputs[..., :2])) != 3
for side_outputs, side in zip(model_outputs, ("a", "b")):
logger.debug("side: '%s', outputs: '%s'", side, side_outputs)
layer_name = side_outputs[0][0]
@@ -618,8 +640,10 @@ def _parse_outputs(self, event):
layer_outputs = self._get_outputs(output_config)
for output in layer_outputs: # Drill into sub-model to get the actual output names
loss_name = output[0][0]
- if not split_output:  # Rename losses to reflect the side's output
-     loss_name = f"{loss_name.replace('_both', '')}_{side}"
+ if loss_name[-2:] not in ("_a", "_b"):  # Rename losses to reflect the side output
+     new_name = f"{loss_name.replace('_both', '')}_{side}"
+     logger.debug("Renaming loss output from '%s' to '%s'", loss_name, new_name)
+     loss_name = new_name
if loss_name not in self._loss_labels:
logger.debug("Adding loss name: '%s'", loss_name)
self._loss_labels.append(loss_name)
@@ -650,6 +674,28 @@ def _get_outputs(cls, model_config):
outputs, outputs.shape)
return outputs

+ def _add_amd_loss_labels(self, session_id):
+     """ It is not possible to store the model config in the Tensorboard logs for AMD so we
+     need to obtain the loss labels from the model's state file. This is called now so we know
+     event data is being written, and therefore the most current loss label data is available
+     in the state file.
+
+     Loss names are added to :attr:`_loss_labels`
+
+     Parameters
+     ----------
+     session_id: int
+         The session id that the data is being cached for
+     """
+     if self._cache._loss_labels:  # pylint:disable=protected-access
+         return
+     # Import global session here to prevent circular import
+     from . import Session  # pylint:disable=import-outside-toplevel
+     loss_labels = sorted(Session.get_loss_keys(session_id=session_id))
+     self._loss_labels = loss_labels
+     logger.debug("Collated loss labels: %s", self._loss_labels)

@classmethod
def _process_event(cls, event, step):
""" Process a single Tensorflow event.
@@ -670,8 +716,19 @@ def _process_event(cls, event, step):
The given step `dict` with the given event data added to it.
"""
summary = event.summary.value[0]

if summary.tag in ("batch_loss", "batch_total"): # Pre tf2.3 totals were "batch_total"
step["timestamp"] = event.wall_time
return step
step.setdefault("loss", list()).append(summary.simple_value)

+ loss = summary.simple_value
+ if not loss:
+     # Need to convert a tensor to a float for TF2.8 logged data. This maybe due to change
+     # in logging or may be due to work around put in place in FS training function for the
+     # following bug in TF 2.8 when writing records:
+     # https://github.com/keras-team/keras/issues/16173
+     loss = float(tf.make_ndarray(summary.tensor))

step.setdefault("loss", []).append(loss)

return step