-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformer model based on Tensorlayer #1027
Open
ArnoldLIULJ
wants to merge
31
commits into
master
Choose a base branch
from
transformer
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
5b85af7
transformer updated
4d19a5a
Merge branch 'master' into transformer
zsdonghao d1a20df
minor change
2e44277
Merge branch 'master' of https://github.com/tensorlayer/tensorlayer i…
7717b74
minor
1b349bb
Merge branch 'master' into transformer
zsdonghao 1df422c
merge
21161cb
adjust files
412eadf
Merge branch 'master' into transformer
ArnoldLIULJ 6ecca88
merge
005ab91
attention visualisation
8911654
add attention visualisation
61bf27f
optimizer update
3ef8d8b
fix
048d9a3
add attention visualisation
a47aee1
add attention visualisation
3c4cae1
add decoder part attention visualisation
4d2e19e
documentation
f5438a7
documentation
a48e1d3
documentation
90d536e
add examples
e2662c2
documentation
e0e81f0
documentation
80c985c
doc
2f316b0
doc
990e014
reverse change
2c1ced8
reverse change
9144165
doc
576af52
optimizer
a2a1cbf
doc
0670f1c
Merge branch 'master' into transformer
zsdonghao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
import tensorflow_datasets as tfds | ||
import tensorflow as tf | ||
import time | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from tensorlayer.models.transformer import Transformer | ||
from tensorlayer.models.transformer.utils import metrics | ||
from tensorlayer.models.transformer.utils import attention_visualisation | ||
import tensorlayer as tl | ||
""" Translation from Portugese to English by Transformer model | ||
This tutorial provides basic instructions on how to define and train Transformer model on Tensorlayer for | ||
Translation task. You can also learn how to visualize the attention block via this tutorial. | ||
""" | ||
|
||
|
||
def set_up_dataset(): | ||
# Set up dataset for Portugese-English translation from the TED Talks Open Translation Project. | ||
# This dataset contains approximately 50000 training examples, 1100 validation examples, and 2000 test examples. | ||
# https://www.ted.com/participate/translate | ||
|
||
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True) | ||
train_examples, val_examples = examples['train'], examples['validation'] | ||
|
||
# Set up tokenizer and save the tokenizer | ||
tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus( | ||
(en.numpy() and pt.numpy() for pt, en in train_examples), target_vocab_size=2**14 | ||
) | ||
|
||
tokenizer.save_to_file("tokenizer") | ||
tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file("tokenizer") | ||
|
||
return tokenizer, train_examples | ||
|
||
|
||
def test_tokenizer_success(tokenizer): | ||
sample_string = 'TensorLayer is awesome.' | ||
|
||
tokenized_string = tokenizer.encode(sample_string) | ||
print('Tokenized string is {}'.format(tokenized_string)) | ||
|
||
original_string = tokenizer.decode(tokenized_string) | ||
print('The original string: {}'.format(original_string)) | ||
assert original_string == sample_string | ||
|
||
|
||
def generate_training_dataset(train_examples, tokenizer): | ||
|
||
def encode(lang1, lang2): | ||
lang1 = tokenizer.encode(lang1.numpy()) + [tokenizer.vocab_size + 1] | ||
|
||
lang2 = tokenizer.encode(lang2.numpy()) + [tokenizer.vocab_size + 1] | ||
|
||
return lang1, lang2 | ||
|
||
MAX_LENGTH = 50 | ||
|
||
def filter_max_length(x, y, max_length=MAX_LENGTH): | ||
return tf.logical_and(tf.size(x) <= max_length, tf.size(y) <= max_length) | ||
|
||
def tf_encode(pt, en): | ||
return tf.py_function(encode, [pt, en], [tf.int64, tf.int64]) | ||
|
||
train_dataset = train_examples.map(tf_encode) | ||
train_dataset = train_dataset.filter(filter_max_length) | ||
# cache the dataset to memory to get a speedup while reading from it. | ||
train_dataset = train_dataset.cache() | ||
BUFFER_SIZE = 20000 | ||
BATCH_SIZE = 64 | ||
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([-1], [-1])) | ||
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE) | ||
|
||
return train_dataset | ||
|
||
|
||
def model_setup(tokenizer): | ||
# define Hyper parameters for transformer | ||
class HYPER_PARAMS(object): | ||
vocab_size = tokenizer.vocab_size + 10 | ||
encoder_num_layers = 4 | ||
decoder_num_layers = 4 | ||
hidden_size = 128 | ||
ff_size = 512 | ||
num_heads = 8 | ||
keep_prob = 0.9 | ||
|
||
# Default prediction params | ||
extra_decode_length = 50 | ||
beam_size = 5 | ||
alpha = 0.6 # used to calculate length normalization in beam search | ||
|
||
label_smoothing = 0.1 | ||
learning_rate = 2.0 | ||
learning_rate_decay_rate = 1.0 | ||
learning_rate_warmup_steps = 4000 | ||
|
||
sos_id = 0 | ||
eos_id = tokenizer.vocab_size + 1 | ||
|
||
model = Transformer(HYPER_PARAMS) | ||
|
||
# Set the optimizer | ||
learning_rate = CustomSchedule(HYPER_PARAMS.hidden_size, warmup_steps=HYPER_PARAMS.learning_rate_warmup_steps) | ||
optimizer = tl.optimizers.LazyAdamOptimizer(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) | ||
return model, optimizer, HYPER_PARAMS | ||
|
||
|
||
# Use the Adam optimizer with a custom learning rate scheduler according to the formula in the Paper "Attention is All you need" | ||
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): | ||
|
||
def __init__(self, d_model, warmup_steps=5): | ||
super(CustomSchedule, self).__init__() | ||
|
||
self.d_model = d_model | ||
self.d_model = tf.cast(self.d_model, tf.float32) | ||
|
||
self.warmup_steps = warmup_steps | ||
|
||
def __call__(self, step): | ||
arg1 = tf.math.rsqrt(step) | ||
arg2 = step * (self.warmup_steps**-1.5) | ||
|
||
return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) | ||
|
||
|
||
def tutorial_transformer(): | ||
tokenizer, train_examples = set_up_dataset() | ||
train_dataset = generate_training_dataset(train_examples, tokenizer) | ||
model, optimizer, HYPER_PARAMS = model_setup(tokenizer) | ||
|
||
num_epochs = 10 | ||
for epoch in range(num_epochs): | ||
model.train() | ||
for (batch, (inp, tar)) in enumerate(train_dataset): | ||
with tf.GradientTape() as tape: | ||
logits, weights_encoder, weights_decoder = model(inputs=inp, targets=tar) | ||
logits = metrics.MetricLayer(HYPER_PARAMS.vocab_size)([logits, tar]) | ||
logits, loss = metrics.LossLayer(HYPER_PARAMS.vocab_size, 0.1)([logits, tar]) | ||
grad = tape.gradient(loss, model.all_weights) | ||
optimizer.apply_gradients(zip(grad, model.all_weights)) | ||
if (batch % 50 == 0): | ||
print('Batch ID {} at Epoch [{}/{}]: loss {:.4f}'.format(batch, epoch + 1, num_epochs, loss)) | ||
|
||
model.eval() | ||
sentence_en = tokenizer.encode('TensorLayer is awesome.') | ||
[prediction, weights_decoder], weights_encoder = model(inputs=[sentence_en]) | ||
|
||
predicted_sentence = tokenizer.decode([i for i in prediction["outputs"][0] if i < tokenizer.vocab_size]) | ||
print("Translated: ", predicted_sentence) | ||
|
||
# visualize the self attention | ||
tokenizer_str = [tokenizer.decode([ts]) for ts in (sentence_en)] | ||
attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], tokenizer_str, tokenizer_str) | ||
|
||
|
||
if __name__ == "__main__": | ||
tutorial_transformer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .attention_layer import * | ||
from .transformer import Transformer | ||
from .beamsearchHelper import * | ||
from .feedforward_layer import * | ||
from .embedding_layer import * | ||
from .utils import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Implementation of multiheaded attention and self-attention layers.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
import tensorlayer as tl | ||
|
||
|
||
class MultiHeadAttentionLayer(tl.layers.Layer): | ||
"""The :class:`MultiHeadAttentionLayer` layer is for multi-head attention computation. | ||
The weight computation is between "key" and "query", which will then matmul with "value" to generate information | ||
that selectively focuses on the "query" messages. | ||
|
||
Parameters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing space There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
----------- | ||
num_heads : int | ||
The number of heads which allow attention computation for different features | ||
hidden_size : int | ||
Out dim for the layer | ||
keep_prob : float | ||
Keep probablity for drop-out mechanism between 0 and 1 | ||
""" | ||
|
||
def __init__(self, num_heads, hidden_size, keep_prob): | ||
|
||
if hidden_size % num_heads: | ||
raise ValueError( | ||
"Hidden size ({}) must be divisible by the number of heads ({}).".format(hidden_size, num_heads) | ||
) | ||
|
||
super(MultiHeadAttentionLayer, self).__init__() | ||
self.hidden_size = hidden_size | ||
self.num_heads = num_heads | ||
self.attention_dropout = 1 - keep_prob | ||
|
||
self.build(None) | ||
self._built = True | ||
|
||
def get_config(self): | ||
return { | ||
"hidden_size": self.hidden_size, | ||
"num_heads": self.num_heads, | ||
"attention_dropout": self.attention_dropout, | ||
} | ||
|
||
def build(self, inputs_shape): | ||
|
||
# Transformation for linearly projecting the queries, keys, and values. | ||
self.q_transformation = self._get_weights( | ||
"q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform') | ||
) | ||
self.v_transformation = self._get_weights( | ||
"v_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform') | ||
) | ||
self.k_transformation = self._get_weights( | ||
"k_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform') | ||
) | ||
self.out_transformation = self._get_weights( | ||
"out_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform') | ||
) | ||
|
||
def split_heads(self, x): | ||
|
||
with tf.name_scope("split_heads"): | ||
batch_size = tf.shape(x)[0] | ||
length = tf.shape(x)[1] | ||
|
||
# Calculate depth of last dimension after it has been split. | ||
depth = (self.hidden_size // self.num_heads) | ||
|
||
# Split the last dimension | ||
x = tf.reshape(x, [batch_size, length, self.num_heads, depth]) | ||
|
||
# Transpose the result | ||
return tf.transpose(x, [0, 2, 1, 3]) | ||
|
||
def combine_heads(self, x): | ||
|
||
with tf.name_scope("combine_heads"): | ||
batch_size = tf.shape(x)[0] | ||
length = tf.shape(x)[2] | ||
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth] | ||
return tf.reshape(x, [batch_size, length, self.hidden_size]) | ||
|
||
def forward(self, x, y, mask, cache=None): | ||
"""Apply attention mechanism to x and y.""" | ||
# Linearly project the query (q), key (k) and value (v) using different | ||
# learned projections. This is in preparation of splitting them into | ||
# multiple heads. Multi-head attention uses multiple queries, keys, and | ||
# values rather than regular attention (which uses a single q, k, v). | ||
|
||
v = k = y | ||
q = x | ||
|
||
q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]]) | ||
k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]]) | ||
v = tf.tensordot(v, self.v_transformation, axes=[[2], [0]]) | ||
|
||
if cache is not None: | ||
|
||
# Combine cached keys and values with new keys and values. | ||
k = tf.concat([cache["k"], k], axis=1) | ||
v = tf.concat([cache["v"], v], axis=1) | ||
|
||
# Update cache | ||
cache["k"] = k | ||
cache["v"] = v | ||
|
||
# Split q, k, v into heads. | ||
q = self.split_heads(q) | ||
k = self.split_heads(k) | ||
v = self.split_heads(v) #(Batch, num_head, length_v, dk) | ||
|
||
# Scale q to prevent the dot product between q and k from growing too large. | ||
depth = (self.hidden_size // self.num_heads) | ||
q *= depth**-0.5 | ||
|
||
# Calculate dot product attention | ||
logits = tf.matmul(q, k, transpose_b=True) #(Batch, num_head, length_q, length_k) | ||
logits += mask | ||
weights = tf.nn.softmax(logits, name="attention_weights") #(Batch, num_head, length_q, length_k) | ||
weights_store = weights | ||
if self.is_train: | ||
weights = tf.nn.dropout(weights, rate=self.attention_dropout) | ||
|
||
attention_output = tf.matmul(weights, v) | ||
|
||
# Recombine heads --> [batch_size, length, hidden_size] | ||
attention_output = self.combine_heads(attention_output) | ||
|
||
# Run the combined outputs through another linear projection layer. | ||
attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]]) | ||
return attention_output, weights_store | ||
|
||
|
||
class SelfAttentionLayer(MultiHeadAttentionLayer): | ||
"""Multiheaded self-attention layer.""" | ||
|
||
def forward(self, inputs, mask, cache=None): | ||
return super(SelfAttentionLayer, self).forward(x=inputs, y=inputs, mask=mask, cache=cache) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .beam_search import * |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MultiHeadAttention is better than MultiHeadAttentionLayer?