Freeze layers for transfer learning #3247

Open · wants to merge 2 commits into master
21 changes: 18 additions & 3 deletions training/deepspeech_training/train.py
@@ -29,7 +29,7 @@
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.checkpoints import drop_freeze_number_to_layers, load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
from .util.flags import create_flags, FLAGS
@@ -322,8 +322,24 @@ def get_tower_results(iterator, optimizer, dropout_rates):
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Filter out layers if we want to freeze some
if FLAGS.freeze_source_layers > 0:
    filter_vars = drop_freeze_number_to_layers(FLAGS.freeze_source_layers, "freeze")
    new_train_vars = list(train_vars)
Collaborator:
Is there a reason not to build up new_train_vars from empty, something like:

new_train_vars = []
for tv in train_vars:
    if tv.name not in filter_vars:
        new_train_vars.append(tv)
train_vars = new_train_vars

Contributor Author:
It seemed more intuitive: we want to train everything except the filtered layers.
Your example doesn't work, by the way, because filter_vars contains short keys like layer_1, while train_vars holds full variable names like layer_1:dense:0.

Collaborator:
Hence the "something like" :) -- You could of course use find() or something like that. I have no particularly strong feeling about it, but err on the side of simplicity.

    for fv in filter_vars:
        for tv in train_vars:
            if fv in tv.name:
                new_train_vars.remove(tv)
    train_vars = new_train_vars
    msg = "Tower {} - Training only variables: {}"
    print(msg.format(i, [v.name for v in train_vars]))
else:
    print("Tower {} - Training all layers".format(i))

# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
Collaborator:
I'd like someone else to take a look at this.
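
For context, a minimal standalone sketch (assuming TF 1.x graph mode via tf.compat.v1; the variable names are illustrative) of how var_list keeps frozen variables out of the returned gradient pairs, so apply_gradients never updates them:

import tensorflow.compat.v1 as tfv1

tfv1.disable_eager_execution()

x = tfv1.placeholder(tfv1.float32, shape=[None, 1])
frozen_w = tfv1.get_variable('frozen_w', shape=[1, 1])
trainable_w = tfv1.get_variable('trainable_w', shape=[1, 1])
loss = tfv1.reduce_mean(tfv1.square(tfv1.matmul(x, frozen_w) + tfv1.matmul(x, trainable_w)))

optimizer = tfv1.train.AdamOptimizer(0.001)
# Only trainable_w appears in the (gradient, variable) pairs; frozen_w is never updated.
grads_and_vars = optimizer.compute_gradients(loss, var_list=[trainable_w])
train_op = optimizer.apply_gradients(grads_and_vars)
print([v.name for _, v in grads_and_vars])  # ['trainable_w:0']

A side effect worth noting: Adam only creates its moment slots for variables it actually updates, which is why the checkpoint-loading changes later in this diff have to tolerate missing Adam tensors.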


# Retain tower's gradients
tower_gradients.append(gradients)
@@ -671,7 +687,6 @@ def __call__(self, progress, data, **kwargs):

print('-' * 80)


except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
65 changes: 42 additions & 23 deletions training/deepspeech_training/util/checkpoints.py
@@ -1,9 +1,9 @@
import sys
import tensorflow as tf

import tensorflow.compat.v1 as tfv1

from .flags import FLAGS
from .logging import log_info, log_error, log_warn
from .logging import log_error, log_info, log_warn
Collaborator:
Why change the order here?

Contributor Author:
Autosort?

Collaborator:
I don't know what the DS policy is for that, I'd have to ask.



def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -19,32 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
# compatibility with older checkpoints.
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
if lr_var and ('learning_rate' not in vars_in_ckpt or
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
Collaborator:
Spacing only change...

Contributor Author:
Fixing: PEP 8: E127 continuation line over-indented for visual indent

assert len(lr_var) <= 1
load_vars -= lr_var
init_vars |= lr_var

if FLAGS.load_cudnn:
# Initialize training from a CuDNN RNN checkpoint
# Identify the variables which we cannot load, and set them
# for initialization
missing_vars = set()
for v in load_vars:
if v.op.name not in vars_in_ckpt:
log_warn('CUDNN variable not found: %s' % (v.op.name))
missing_vars.add(v)
# After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
# are missing because they were not used. This might also occur when loading a cudnn checkpoint
# Therefore we have to initialize them again to continue training on such checkpoints
print_msg = False
for v in load_vars:
if v.op.name not in vars_in_ckpt:
if 'Adam' in v.name:
init_vars.add(v)
print_msg = True
if print_msg:
msg = "Some Adam tensors are missing, they will be initialized automatically."
log_info(msg)
Collaborator:
Can you not do just log_info("...") ?

Collaborator:
Maybe something like

missing = []
if ...
missing.append(v)

if missing:
  for v in missing:
    log_info("Missing... {}".format(v))

Contributor Author:
The split into two parts was an autoformatting issue; black would have split this into three lines.

I didn't print every missing layer because there are messages later that state exactly which layers were reinitialized.

load_vars -= init_vars

load_vars -= init_vars

# Check that the only missing variables (i.e. those to be initialised)
# are the Adam moment tensors, if they aren't then we have an issue
missing_var_names = [v.op.name for v in missing_vars]
if any('Adam' not in v for v in missing_var_names):
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors. Missing variables: {}'.format(missing_var_names))
sys.exit(1)
if FLAGS.load_cudnn:
# Check all required tensors are included in the cudnn checkpoint we want to load
for v in load_vars:
if v.op.name not in vars_in_ckpt and 'Adam' not in v.op.name:
msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
' variable other than an Adam moment tensor: {}'
log_error(msg.format(v.op.name))
sys.exit(1)

if allow_drop_layers and FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
@@ -59,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
'dropping only 5 layers.')
FLAGS.drop_source_layers = 5

dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
# Initialize all variables needed for DS, but not loaded from ckpt
for v in load_vars:
if any(layer in v.op.name for layer in dropped_layers):
@@ -75,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
session.run(v.initializer)


def drop_freeze_number_to_layers(drop_freeze_number, mode):
    """ Convert number of layers to drop or freeze into layer names """

    if drop_freeze_number >= 6:
        log_warn('The checkpoint only has 6 layers, but you are trying '
                 'to drop or freeze all of them or more. Continuing with 5 layers.')
        drop_freeze_number = 5

    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
    if mode == "drop":
        layer_keys = layer_keys[-1 * int(drop_freeze_number):]
    elif mode == "freeze":
        layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
    else:
        raise ValueError
    return layer_keys
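
For reference (not part of the diff), a quick illustration of how the two modes slice the layer list, e.g. for drop_freeze_number = 2:

layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
print(layer_keys[-2:])   # "drop":   ['layer_5', 'layer_6'] are re-initialized (closest to the output)
print(layer_keys[:-2])   # "freeze": ['layer_1', 'layer_2', 'layer_3', 'lstm'] are excluded from training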


def _checkpoint_path_or_none(checkpoint_filename):
checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
if not checkpoint:
3 changes: 2 additions & 1 deletion training/deepspeech_training/util/flags.py
@@ -92,7 +92,8 @@ def create_flags():

# Transfer Learning

f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output == 2, etc)')
f.DEFINE_integer('freeze_source_layers', 0, 'freeze layer weights (to freeze all but output == 1, freeze all but penultimate and output == 2, etc). Normally used in combination with "drop_source_layers" flag and should be used in a two step training (first drop and freeze layers and train a few epochs, second continue without both flags)')

# Exporting
