Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port training to TF2 #283

Open
wants to merge 2 commits into
base: users/boomanaiden154/main.port-training-to-tf2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions gematria/model/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ gematria_py_test(
name = "training_test",
size = "small",
srcs = ["training_test.py"],
tags = [
"manual",
],
deps = [
":training",
"//gematria/testing/python:basic_blocks_with_throughput",
Expand Down
25 changes: 10 additions & 15 deletions gematria/model/python/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def batches(


def partially_restore_from_checkpoint(
checkpoint_file: str, load_global_step_from_ckpt: bool, sess: tf.Session
checkpoint_file: str, load_step_from_ckpt: bool, model: tf.Module
) -> None:
"""Partially restores a checkpoint to the current graph.

Expand All @@ -311,29 +311,24 @@ def partially_restore_from_checkpoint(

Args:
checkpoint_file: A checkpoint to partially restore from.
load_global_step_from_ckpt: If True, load global step value from the given
checkpoint file.
sess: A TensorFlow session to restore into.
load_step_from_ckpt: If True, load the step value from the given checkpoint
file.
model: The tf.Module object representing the model that the weights should
be restored into.
"""
reader = tf.train.load_checkpoint(checkpoint_file)
shapes = reader.get_variable_to_shape_map()
dtypes = reader.get_variable_to_dtype_map()

if load_global_step_from_ckpt:
logging.info(
'Loading global step from checkpoint file: %s', checkpoint_file
)
global_step = tf.train.get_global_step()
global_step.load(reader.get_tensor('global_step'), sess)
if load_step_from_ckpt:
raise NotImplementedError()

for variable in tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=None
):
for variable in model.trainable_variables:
# All variable names should end with ':0'; this ':0' is not used in the
# checkpoint.
if not variable.name.endswith(':0'):
continue
variable_name = variable.name[:-2]
variable_name = variable.name[:-2] + '/.ATTRIBUTES/VARIABLE_VALUE'
if variable_name not in shapes:
logging.info('%s not found in the checkpoint', variable_name)
continue
Expand All @@ -354,7 +349,7 @@ def partially_restore_from_checkpoint(
)
continue
logging.info('Restoring %s', variable_name)
variable.load(reader.get_tensor(variable_name), sess)
variable.load(reader.get_tensor(variable_name))


def _as_list(values: Sequence[float]) -> list[float]:
Expand Down
105 changes: 39 additions & 66 deletions gematria/model/python/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gematria.model.python import training
from gematria.testing.python import basic_blocks_with_throughput
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow as tf


class TrainingEpochStatsTest(tf.test.TestCase):
Expand Down Expand Up @@ -442,85 +442,58 @@ def test_both_limits(self, with_throughput):
self.assertSequenceEqual(batches, expected_batches)


class DummyModel(tf.Module):
  """A test model that contains some trainable variables.

  The model owns four variables, stored as the attributes `var1` through
  `var4` and created with the TensorFlow names 'var1' through 'var4'. Each
  variable takes its shape and dtype from the matching spec, and every one of
  its elements is set to `initial_value`.
  """

  def __init__(
      self, initial_value: int, var1_spec, var2_spec, var3_spec, var4_spec
  ):
    """Creates the four variables of the model.

    Args:
      initial_value: The value written to every element of every variable. It
        is cast to the dtype of the corresponding spec.
      var1_spec: Object exposing `shape` and `dtype` (the test passes a
        tf.TensorSpec); describes `var1`.
      var2_spec: Same as `var1_spec`, for `var2`.
      var3_spec: Same as `var1_spec`, for `var3`.
      var4_spec: Same as `var1_spec`, for `var4`.
    """
    # tf.Module sets up the module name and name scope in its constructor, so
    # the base class must be initialized before the module is used.
    super().__init__()
    self.var1 = self._filled_variable('var1', initial_value, var1_spec)
    self.var2 = self._filled_variable('var2', initial_value, var2_spec)
    self.var3 = self._filled_variable('var3', initial_value, var3_spec)
    self.var4 = self._filled_variable('var4', initial_value, var4_spec)

  @staticmethod
  def _filled_variable(name: str, initial_value: int, spec) -> tf.Variable:
    """Returns a variable shaped like `spec` and filled with `initial_value`."""
    # tf.fill produces a tensor in the dtype of `initial_value`; the cast
    # converts it to the dtype requested by the spec.
    filled = tf.fill(spec.shape, initial_value)
    return tf.Variable(tf.cast(filled, dtype=spec.dtype), name=name)


class PartiallyRestoreFromCheckpointTest(tf.test.TestCase):

def test_partially_restore(self):
checkpoint_file = os.path.join(tf.test.get_temp_dir(), 'checkpoint')
v1_name = 'var_1'
checkpoint_folder = os.path.join(self.get_temp_dir(), 'checkpoint')
v1_spec = tf.TensorSpec((3,), dtype=tf.dtypes.int32)

v2_name = 'var_2'
v2_spec_a = tf.TensorSpec((2, 2), dtype=tf.dtypes.float32)
v2_spec_b = tf.TensorSpec((2, 2), dtype=tf.dtypes.int32)

v3_name = 'var_3'
v3_spec_a = tf.TensorSpec((1, 3), dtype=tf.dtypes.float32)
v3_spec_b = tf.TensorSpec((2, 1), dtype=tf.dtypes.float32)

v4_name = 'var_4'
v4_spec = tf.TensorSpec((3,), dtype=tf.dtypes.int32)
with tf.Graph().as_default():
initializer = tf.initializers.constant(1)
v1 = tf.get_variable(
name=v1_name,
shape=v1_spec.shape,
dtype=v1_spec.dtype,
initializer=initializer,
)
v2 = tf.get_variable(
name=v2_name,
shape=v2_spec_a.shape,
dtype=v2_spec_a.dtype,
initializer=initializer,
)
v3 = tf.get_variable(
name=v3_name,
shape=v3_spec_a.shape,
dtype=v3_spec_a.dtype,
initializer=initializer,
)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saved_filename = saver.save(sess, checkpoint_file, global_step=0)

with tf.Graph().as_default():
initializer = tf.initializers.constant(2)
v1 = tf.get_variable(
name=v1_name,
shape=v1_spec.shape,
dtype=v1_spec.dtype,
initializer=initializer,
)
v2 = tf.get_variable(
name=v2_name,
shape=v2_spec_b.shape,
dtype=v2_spec_b.dtype,
initializer=initializer,
)
v3 = tf.get_variable(
name=v3_name,
shape=v3_spec_b.shape,
dtype=v3_spec_b.dtype,
initializer=initializer,
)
v4 = tf.get_variable(
name=v4_name,
shape=v4_spec.shape,
dtype=v4_spec.dtype,
initializer=initializer,
)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
training.partially_restore_from_checkpoint(saved_filename, False, sess)

v1d, v2d, v3d, v4d = sess.run((v1, v2, v3, v4))
self.assertAllEqual(v1d, [1, 1, 1])
self.assertAllEqual(v2d, [[2, 2], [2, 2]])
self.assertAllEqual(v3d, [[2], [2]])
self.assertAllEqual(v4d, [2, 2, 2])
model_a = DummyModel(1, v1_spec, v2_spec_a, v3_spec_a, v4_spec)
checkpoint = tf.train.Checkpoint(model_a)
model_a_save_path = checkpoint.save(checkpoint_folder)

model_b = DummyModel(2, v1_spec, v2_spec_b, v3_spec_b, v4_spec)
training.partially_restore_from_checkpoint(
model_a_save_path, False, model_b
)

self.assertAllEqual(model_b.var1, [1, 1, 1])
self.assertAllEqual(model_b.var2, [[2, 2], [2, 2]])
self.assertAllEqual(model_b.var3, [[2], [2]])
self.assertAllEqual(model_b.var4, [1, 1, 1])


if __name__ == '__main__':
tf.disable_v2_behavior()
tf.test.main()
Loading