-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformer model based on Tensorlayer #1027
base: master
Are you sure you want to change the base?
Changes from 1 commit
5b85af7
4d19a5a
d1a20df
2e44277
7717b74
1b349bb
1df422c
21161cb
412eadf
6ecca88
005ab91
8911654
61bf27f
3ef8d8b
048d9a3
a47aee1
3c4cae1
4d2e19e
f5438a7
a48e1d3
90d536e
e2662c2
e0e81f0
80c985c
2f316b0
990e014
2c1ced8
9144165
576af52
a2a1cbf
0670f1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .attention_layer import * | ||
from .transformer import Transformer | ||
from .beamsearchHelper import * | ||
from .feedforward_layer import * | ||
from .embedding_layer import * | ||
from .utils import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Implementation of multiheaded attention and self-attention layers.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
import tensorlayer as tl | ||
|
||
|
||
class MultiHeadAttentionLayer(tl.layers.Layer): | ||
"""Multi-headed attention layer.""" | ||
|
||
def __init__(self, num_heads, hidden_size, keep_prob): | ||
"""Initialize Attention. | ||
|
||
Args: | ||
hidden_size: int, output dim of hidden layer. | ||
num_heads: int, number of heads to repeat the same attention structure. | ||
keep_prob: float, keep rate for dropout mechanism inside attention for training. | ||
""" | ||
if hidden_size % num_heads: | ||
raise ValueError( | ||
"Hidden size ({}) must be divisible by the number of heads ({}).".format(hidden_size, num_heads) | ||
) | ||
|
||
super(MultiHeadAttentionLayer, self).__init__() | ||
self.hidden_size = hidden_size | ||
self.num_heads = num_heads | ||
self.attention_dropout = 1 - keep_prob | ||
|
||
self.build(None) | ||
self._built = True | ||
|
||
def get_config(self): | ||
return { | ||
"hidden_size": self.hidden_size, | ||
"num_heads": self.num_heads, | ||
"attention_dropout": self.attention_dropout, | ||
} | ||
|
||
def build(self, inputs_shape): | ||
# Transformation for linearly projecting the queries, keys, and values. | ||
self.q_transformation = self._get_weights( | ||
"q_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform') | ||
) | ||
self.v_transformation = self._get_weights( | ||
"v_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform') | ||
) | ||
self.k_transformation = self._get_weights( | ||
"k_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform') | ||
) | ||
self.out_transformation = self._get_weights( | ||
"out_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform') | ||
) | ||
|
||
def split_heads(self, x): | ||
"""Split x into different heads, and transpose the resulting value. | ||
|
||
The tensor is transposed to insure the inner dimensions hold the correct | ||
values during the matrix multiplication. | ||
|
||
Args: | ||
x: A tensor with shape [batch_size, length, hidden_size] | ||
|
||
Returns: | ||
A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads] | ||
""" | ||
with tf.name_scope("split_heads"): | ||
batch_size = tf.shape(x)[0] | ||
length = tf.shape(x)[1] | ||
|
||
# Calculate depth of last dimension after it has been split. | ||
depth = (self.hidden_size // self.num_heads) | ||
|
||
# Split the last dimension | ||
x = tf.reshape(x, [batch_size, length, self.num_heads, depth]) | ||
|
||
# Transpose the result | ||
return tf.transpose(x, [0, 2, 1, 3]) | ||
|
||
def combine_heads(self, x): | ||
"""Combine tensor that has been split. | ||
|
||
Args: | ||
x: A tensor [batch_size, num_heads, length, hidden_size/num_heads] | ||
|
||
Returns: | ||
A tensor with shape [batch_size, length, hidden_size] | ||
""" | ||
with tf.name_scope("combine_heads"): | ||
batch_size = tf.shape(x)[0] | ||
length = tf.shape(x)[2] | ||
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth] | ||
return tf.reshape(x, [batch_size, length, self.hidden_size]) | ||
|
||
def forward(self, inputs, mask, cache=None): | ||
"""Apply attention mechanism to x and y. | ||
|
||
Args: | ||
x: a tensor with shape [batch_size, length_x, hidden_size] | ||
y: a tensor with shape [batch_size, length_y, hidden_size] | ||
mask: attention bias that will be added to the result of the dot product. | ||
training: boolean, whether in training mode or not. | ||
cache: (Used during prediction) dictionary with tensors containing results | ||
of previous attentions. The dictionary must have the items: | ||
{"k": tensor with shape [batch_size, i, key_channels], | ||
"v": tensor with shape [batch_size, i, value_channels]} | ||
where i is the current decoded length. | ||
|
||
Returns: | ||
Attention layer output with shape [batch_size, length_x, hidden_size] | ||
""" | ||
# Linearly project the query (q), key (k) and value (v) using different | ||
# learned projections. This is in preparation of splitting them into | ||
# multiple heads. Multi-head attention uses multiple queries, keys, and | ||
# values rather than regular attention (which uses a single q, k, v). | ||
|
||
if (len(inputs) == 2): | ||
q = inputs[0] | ||
k = v = inputs[1] | ||
|
||
if (len(inputs) == 3): | ||
q = inputs[0] | ||
k = inputs[1] | ||
v = inputs[2] | ||
|
||
q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]]) | ||
k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]]) | ||
v = tf.tensordot(v, self.v_transformation, axes=[[2], [0]]) | ||
|
||
if cache is not None: | ||
|
||
# Combine cached keys and values with new keys and values. | ||
k = tf.concat([cache["k"], k], axis=1) | ||
v = tf.concat([cache["v"], v], axis=1) | ||
|
||
# Update cache | ||
cache["k"] = k | ||
cache["v"] = v | ||
|
||
# Split q, k, v into heads. | ||
q = self.split_heads(q) | ||
k = self.split_heads(k) | ||
v = self.split_heads(v) #(Batch, num_head, length_v, dk) | ||
|
||
# Scale q to prevent the dot product between q and k from growing too large. | ||
depth = (self.hidden_size // self.num_heads) | ||
q *= depth**-0.5 | ||
|
||
# Calculate dot product attention | ||
logits = tf.matmul(q, k, transpose_b=True) #(Batch, num_head, length_q, length_k) | ||
logits += mask | ||
weights = tf.nn.softmax(logits, name="attention_weights") #(Batch, num_head, length_q, length_k) | ||
if self.is_train: | ||
weights = tf.nn.dropout(weights, rate=self.attention_dropout) | ||
|
||
attention_output = tf.matmul(weights, v) | ||
|
||
# Recombine heads --> [batch_size, length, hidden_size] | ||
attention_output = self.combine_heads(attention_output) | ||
|
||
# Run the combined outputs through another linear projection layer. | ||
attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]]) | ||
return attention_output | ||
|
||
|
||
class SelfAttentionLayer(MultiHeadAttentionLayer): | ||
"""Multiheaded self-attention layer.""" | ||
|
||
def forward(self, inputs, mask, cache=None): | ||
return super(SelfAttentionLayer, self).forward(inputs=[inputs, inputs], mask=mask, cache=cache) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .beam_search import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Beam search in TF v2. | ||
""" | ||
|
||
import tensorflow as tf | ||
import tensorlayer.models.transformer.beamsearchHelper.beam_search_v1 as v1 | ||
|
||
_StateKeys = v1._StateKeys # pylint: disable=protected-access | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should avoid global variables in the library as much as possible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will try to optimise this one. |
||
|
||
|
||
class SequenceBeamSearchV2(v1.SequenceBeamSearch): | ||
"""Implementation of beam search loop in v2.""" | ||
|
||
def search(self, initial_ids, initial_cache): | ||
"""Beam search for sequences with highest scores.""" | ||
state, state_shapes = self._create_initial_state(initial_ids, initial_cache) | ||
finished_state = tf.while_loop( | ||
self._continue_search, self._search_step, loop_vars=[state], shape_invariants=[state_shapes], | ||
parallel_iterations=1, back_prop=False | ||
) | ||
finished_state = finished_state[0] | ||
|
||
alive_seq = finished_state[_StateKeys.ALIVE_SEQ] | ||
alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS] | ||
finished_seq = finished_state[_StateKeys.FINISHED_SEQ] | ||
finished_scores = finished_state[_StateKeys.FINISHED_SCORES] | ||
finished_flags = finished_state[_StateKeys.FINISHED_FLAGS] | ||
|
||
# Account for corner case where there are no finished sequences for a | ||
# particular batch item. In that case, return alive sequences for that batch | ||
# item. | ||
finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq, alive_seq) | ||
finished_scores = tf.where(tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs) | ||
return finished_seq, finished_scores | ||
|
||
|
||
def sequence_beam_search( | ||
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, alpha, max_decode_length, eos_id | ||
): | ||
"""Search for sequence of subtoken ids with the largest probability. | ||
|
||
Args: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. incorrect RST format, should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
symbols_to_logits_fn: A function that takes in ids, index, and cache as | ||
arguments. The passed in arguments will have shape: | ||
ids -> [batch_size * beam_size, index] | ||
index -> [] (scalar) | ||
cache -> nested dictionary of tensors [batch_size * beam_size, ...] | ||
The function must return logits and new cache. | ||
logits -> [batch * beam_size, vocab_size] | ||
new cache -> same shape/structure as inputted cache | ||
initial_ids: Starting ids for each batch item. | ||
int32 tensor with shape [batch_size] | ||
initial_cache: dict containing starting decoder variables information | ||
vocab_size: int size of tokens | ||
beam_size: int number of beams | ||
alpha: float defining the strength of length normalization | ||
max_decode_length: maximum length to decoded sequence | ||
eos_id: int id of eos token, used to determine when a sequence has finished | ||
|
||
Returns: | ||
Top decoded sequences [batch_size, beam_size, max_decode_length] | ||
sequence scores [batch_size, beam_size] | ||
""" | ||
batch_size = tf.shape(initial_ids)[0] | ||
|
||
sbs = SequenceBeamSearchV2( | ||
symbols_to_logits_fn, vocab_size, batch_size, beam_size, alpha, max_decode_length, eos_id | ||
) | ||
return sbs.search(initial_ids, initial_cache) | ||
|
||
|
||
def _expand_to_same_rank(tensor, target): | ||
"""Expands a given tensor to target's rank to be broadcastable. | ||
|
||
Args: | ||
tensor: input tensor to tile. Shape: [b, d1, ..., da] | ||
target: target tensor. Shape: [b, d1, ..., da, ..., dn] | ||
|
||
Returns: | ||
Tiled tensor of shape [b, d1, ..., da, 1, ..., 1] with same rank of target. | ||
|
||
Raises: | ||
ValueError, if the shape rank of rank tensor/target is None. | ||
""" | ||
if tensor.shape.rank is None: | ||
raise ValueError("Expect rank for tensor shape, but got None.") | ||
if target.shape.rank is None: | ||
raise ValueError("Expect rank for target shape, but got None.") | ||
|
||
with tf.name_scope("expand_rank"): | ||
diff_rank = target.shape.rank - tensor.shape.rank | ||
for _ in range(diff_rank): | ||
tensor = tf.expand_dims(tensor, -1) | ||
return tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MultiHeadAttention is better than MultiHeadAttentionLayer?