Transformer model based on TensorLayer #1027

Open · wants to merge 31 commits into base: master
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -102,6 +102,7 @@ This release is compatible with TensorFlow 2 RC1.
- Support string dtype in InputLayer (#PR 1017)
- Support Dynamic RNN in RNN (#PR 1023)
- Add ResNet50 static model (#PR 1030)
- Add Transformer model (#PR 1027)
- Add performance test code in static model (#PR 1041)

### Changed
@@ -139,8 +140,8 @@ This release is compatible with TensorFlow 2 RC1.
- @luomai
- @ChrisWu1997: #1010 #1015 #1025 #1030 #1040
- @warshallrho: #1017 #1021 #1026 #1029 #1032 #1041
- @ArnoldLIULJ: #1023
- @JingqingZ: #1023
- @ArnoldLIULJ: #1023 #1027
- @JingqingZ: #1023 #1027

## [2.1.0]

6 changes: 6 additions & 0 deletions docs/modules/models.rst
@@ -16,6 +16,7 @@ TensorLayer provides many pretrained models, you can easily use the whole or a p
ResNet50
Seq2seq
Seq2seqLuongAttention
   Transformer


Base Model
@@ -57,3 +58,8 @@ Seq2seq Luong Attention
------------------------

.. autoclass:: Seq2seqLuongAttention

Transformer
------------------------

.. autoclass:: Transformer
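
A minimal usage sketch, mirroring the hyperparameter object used in the tutorial
``examples/translation_task/tutorial_transformer.py`` shipped with this PR (the exact attribute
set expected by ``Transformer`` may differ):

.. code-block:: python

    from tensorlayer.models.transformer import Transformer

    # Hyperparameters follow the tutorial in this PR; values are illustrative.
    class HYPER_PARAMS(object):
        vocab_size = 2**14
        encoder_num_layers = 4
        decoder_num_layers = 4
        hidden_size = 128
        ff_size = 512
        num_heads = 8
        keep_prob = 0.9
        extra_decode_length = 50
        beam_size = 5
        alpha = 0.6  # length normalization weight for beam search
        label_smoothing = 0.1
        learning_rate = 2.0
        learning_rate_decay_rate = 1.0
        learning_rate_warmup_steps = 4000
        sos_id = 0
        eos_id = 2**14 + 1

    model = Transformer(HYPER_PARAMS)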
157 changes: 157 additions & 0 deletions examples/translation_task/tutorial_transformer.py
@@ -0,0 +1,157 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import time
import numpy as np
import matplotlib.pyplot as plt
from tensorlayer.models.transformer import Transformer
from tensorlayer.models.transformer.utils import metrics
from tensorlayer.models.transformer.utils import attention_visualisation
import tensorlayer as tl
""" Translation from Portugese to English by Transformer model
This tutorial provides basic instructions on how to define and train Transformer model on Tensorlayer for
Translation task. You can also learn how to visualize the attention block via this tutorial.
"""


def set_up_dataset():
    # Set up the dataset for Portuguese-English translation from the TED Talks Open Translation Project.
# This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples.
# https://www.ted.com/participate/translate

examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']

# Set up tokenizer and save the tokenizer
    # Build a single subword vocabulary shared by both languages, so the same tokenizer
    # can encode the Portuguese inputs and the English targets.
    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        (s.numpy() for pt, en in train_examples for s in (pt, en)), target_vocab_size=2**14
    )

tokenizer.save_to_file("tokenizer")
tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file("tokenizer")

return tokenizer, train_examples


def test_tokenizer_success(tokenizer):
sample_string = 'TensorLayer is awesome.'

tokenized_string = tokenizer.encode(sample_string)
print('Tokenized string is {}'.format(tokenized_string))

original_string = tokenizer.decode(tokenized_string)
print('The original string: {}'.format(original_string))
assert original_string == sample_string


def generate_training_dataset(train_examples, tokenizer):

def encode(lang1, lang2):
lang1 = tokenizer.encode(lang1.numpy()) + [tokenizer.vocab_size + 1]

lang2 = tokenizer.encode(lang2.numpy()) + [tokenizer.vocab_size + 1]

return lang1, lang2

MAX_LENGTH = 50

def filter_max_length(x, y, max_length=MAX_LENGTH):
return tf.logical_and(tf.size(x) <= max_length, tf.size(y) <= max_length)

def tf_encode(pt, en):
return tf.py_function(encode, [pt, en], [tf.int64, tf.int64])

train_dataset = train_examples.map(tf_encode)
train_dataset = train_dataset.filter(filter_max_length)
# cache the dataset to memory to get a speedup while reading from it.
train_dataset = train_dataset.cache()
BUFFER_SIZE = 20000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([-1], [-1]))
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

return train_dataset
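
# Illustrative sanity check (not part of the original tutorial; the helper name is ours):
# take a single batch from the padded, prefetched pipeline above and confirm that it
# yields two [batch_size, max_length_in_batch] int64 token tensors.
def inspect_one_batch(train_dataset):
    for pt_batch, en_batch in train_dataset.take(1):
        print("pt batch shape:", pt_batch.shape, "en batch shape:", en_batch.shape)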


def model_setup(tokenizer):
    # Define hyperparameters for the Transformer
class HYPER_PARAMS(object):
vocab_size = tokenizer.vocab_size + 10
encoder_num_layers = 4
decoder_num_layers = 4
hidden_size = 128
ff_size = 512
num_heads = 8
keep_prob = 0.9

# Default prediction params
extra_decode_length = 50
beam_size = 5
alpha = 0.6 # used to calculate length normalization in beam search

label_smoothing = 0.1
learning_rate = 2.0
learning_rate_decay_rate = 1.0
learning_rate_warmup_steps = 4000

sos_id = 0
eos_id = tokenizer.vocab_size + 1

model = Transformer(HYPER_PARAMS)

# Set the optimizer
learning_rate = CustomSchedule(HYPER_PARAMS.hidden_size, warmup_steps=HYPER_PARAMS.learning_rate_warmup_steps)
optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
return model, optimizer, HYPER_PARAMS


# Use the Adam optimizer with a custom learning rate schedule, following the formula in the paper "Attention Is All You Need".
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

def __init__(self, d_model, warmup_steps=5):
super(CustomSchedule, self).__init__()

self.d_model = d_model
self.d_model = tf.cast(self.d_model, tf.float32)

self.warmup_steps = warmup_steps

def __call__(self, step):
arg1 = tf.math.rsqrt(step)
arg2 = step * (self.warmup_steps**-1.5)

return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
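
# Illustrative check (not part of the original tutorial): CustomSchedule above implements
# lrate = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5) from "Attention Is
# All You Need", i.e. a linear warm-up followed by an inverse-square-root decay. This
# helper simply prints a few sample values; the arguments used here are arbitrary.
def preview_learning_rate_schedule(d_model=128, warmup_steps=4000):
    schedule = CustomSchedule(d_model, warmup_steps=warmup_steps)
    for step in [1.0, 1000.0, 4000.0, 20000.0]:
        print("step {:>7.0f}: lr = {:.6f}".format(step, float(schedule(tf.constant(step)))))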


def tutorial_transformer():
tokenizer, train_examples = set_up_dataset()
train_dataset = generate_training_dataset(train_examples, tokenizer)
model, optimizer, HYPER_PARAMS = model_setup(tokenizer)

num_epochs = 10
for epoch in range(num_epochs):
model.train()
for (batch, (inp, tar)) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits, weights_encoder, weights_decoder = model(inputs=inp, targets=tar)
logits = metrics.MetricLayer(HYPER_PARAMS.vocab_size)([logits, tar])
logits, loss = metrics.LossLayer(HYPER_PARAMS.vocab_size, 0.1)([logits, tar])
grad = tape.gradient(loss, model.all_weights)
optimizer.apply_gradients(zip(grad, model.all_weights))
if (batch % 50 == 0):
print('Batch ID {} at Epoch [{}/{}]: loss {:.4f}'.format(batch, epoch + 1, num_epochs, loss))

model.eval()
sentence_en = tokenizer.encode('TensorLayer is awesome.')
[prediction, weights_decoder], weights_encoder = model(inputs=[sentence_en])

predicted_sentence = tokenizer.decode([i for i in prediction["outputs"][0] if i < tokenizer.vocab_size])
print("Translated: ", predicted_sentence)

# visualize the self attention
tokenizer_str = [tokenizer.decode([ts]) for ts in (sentence_en)]
attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], tokenizer_str, tokenizer_str)


if __name__ == "__main__":
tutorial_transformer()
1 change: 1 addition & 0 deletions tensorlayer/models/__init__.py
@@ -10,3 +10,4 @@
from .vgg import *
from .seq2seq import Seq2seq
from .seq2seq_with_attention import Seq2seqLuongAttention
from .transformer.transformer import Transformer
6 changes: 6 additions & 0 deletions tensorlayer/models/transformer/__init__.py
@@ -0,0 +1,6 @@
from .attention_layer import *
from .transformer import Transformer
from .beamsearchHelper import *
from .feedforward_layer import *
from .embedding_layer import *
from .utils import *
156 changes: 156 additions & 0 deletions tensorlayer/models/transformer/attention_layer.py
@@ -0,0 +1,156 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of multiheaded attention and self-attention layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorlayer as tl


class MultiHeadAttentionLayer(tl.layers.Layer):
Review comment (Member): MultiHeadAttention is better than MultiHeadAttentionLayer?

"""The :class:`MultiHeadAttentionLayer` layer is for multi-head attention computation.
    Attention weights are computed between the "query" and the "key", and the result is then matmul-ed
    with the "value" to produce an output that selectively focuses on the positions relevant to each query.

Parameters
Review comment (Member): missing space
Reply (Author): done

-----------
    num_heads : int
        The number of attention heads, allowing the layer to jointly attend to different feature subspaces
    hidden_size : int
        Output dimension of the layer; must be divisible by ``num_heads``
    keep_prob : float
        Keep probability for the dropout applied to the attention weights, between 0 and 1
"""

def __init__(self, num_heads, hidden_size, keep_prob):

if hidden_size % num_heads:
raise ValueError(
"Hidden size ({}) must be divisible by the number of heads ({}).".format(hidden_size, num_heads)
)

super(MultiHeadAttentionLayer, self).__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_dropout = 1 - keep_prob

self.build(None)
self._built = True

def get_config(self):
return {
"hidden_size": self.hidden_size,
"num_heads": self.num_heads,
"attention_dropout": self.attention_dropout,
}

def build(self, inputs_shape):

# Transformation for linearly projecting the queries, keys, and values.
self.q_transformation = self._get_weights(
"q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
)
self.v_transformation = self._get_weights(
"v_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
)
self.k_transformation = self._get_weights(
"k_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
)
self.out_transformation = self._get_weights(
"out_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
)

def split_heads(self, x):

with tf.name_scope("split_heads"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]

# Calculate depth of last dimension after it has been split.
depth = (self.hidden_size // self.num_heads)

# Split the last dimension
x = tf.reshape(x, [batch_size, length, self.num_heads, depth])

# Transpose the result
return tf.transpose(x, [0, 2, 1, 3])

def combine_heads(self, x):

with tf.name_scope("combine_heads"):
batch_size = tf.shape(x)[0]
length = tf.shape(x)[2]
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
return tf.reshape(x, [batch_size, length, self.hidden_size])

def forward(self, x, y, mask, cache=None):
"""Apply attention mechanism to x and y."""
# Linearly project the query (q), key (k) and value (v) using different
# learned projections. This is in preparation of splitting them into
# multiple heads. Multi-head attention uses multiple queries, keys, and
# values rather than regular attention (which uses a single q, k, v).

v = k = y
q = x

q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]])
k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]])
v = tf.tensordot(v, self.v_transformation, axes=[[2], [0]])

if cache is not None:

# Combine cached keys and values with new keys and values.
k = tf.concat([cache["k"], k], axis=1)
v = tf.concat([cache["v"], v], axis=1)

# Update cache
cache["k"] = k
cache["v"] = v

# Split q, k, v into heads.
q = self.split_heads(q)
k = self.split_heads(k)
v = self.split_heads(v) #(Batch, num_head, length_v, dk)

# Scale q to prevent the dot product between q and k from growing too large.
depth = (self.hidden_size // self.num_heads)
q *= depth**-0.5

# Calculate dot product attention
logits = tf.matmul(q, k, transpose_b=True) #(Batch, num_head, length_q, length_k)
logits += mask
weights = tf.nn.softmax(logits, name="attention_weights") #(Batch, num_head, length_q, length_k)
weights_store = weights
if self.is_train:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)

attention_output = tf.matmul(weights, v)

# Recombine heads --> [batch_size, length, hidden_size]
attention_output = self.combine_heads(attention_output)

# Run the combined outputs through another linear projection layer.
attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]])
return attention_output, weights_store


class SelfAttentionLayer(MultiHeadAttentionLayer):
"""Multiheaded self-attention layer."""

def forward(self, inputs, mask, cache=None):
return super(SelfAttentionLayer, self).forward(x=inputs, y=inputs, mask=mask, cache=cache)
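
For reference, the per-head computation in forward() above reduces to standard scaled dot-product
attention (the layer scales q by depth**-0.5 before the matmul, which is equivalent to dividing the
logits by sqrt(depth)). A minimal standalone sketch in plain TensorFlow, not part of the PR:

import tensorflow as tf


def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: [batch, num_heads, length, depth]; mask is additive
    # (0 for visible positions, a large negative value for masked ones).
    depth = tf.cast(tf.shape(q)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(depth)
    if mask is not None:
        logits += mask
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights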
1 change: 1 addition & 0 deletions tensorlayer/models/transformer/beamsearchHelper/__init__.py
@@ -0,0 +1 @@
from .beam_search import *