From 386d7a7981d5ef89e9a68be840039d99dfeff5f4 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 14 Mar 2022 21:31:00 -0700
Subject: [PATCH] [Decoding] Update incremental decoding implementation

---
 src/gluonnlp/models/transformer.py | 49 +++++++++++++++++++++---------
 1 file changed, 34 insertions(+), 15 deletions(-)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 646bcea808..9bd196f163 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -634,7 +634,7 @@ def init_states(self, batch_size, ctx, dtype='float32'):
                                     self._units // self._num_heads), ctx=ctx, dtype=dtype)
         return init_key, init_value
 
-    def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=None):
+    def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=None, prev_inter_kv=None):
         """Incrementally generate the output given the decoder input.
 
         Parameters
@@ -670,6 +670,8 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
         mem_attn_mask
             The attention mask between data and the memory
             Has shape (batch_size, 1, mem_length)
+        prev_inter_kv
+            The previously cached inter_attention keys, values
 
         Returns
         -------
@@ -716,13 +718,20 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
         data = out
         if self._pre_norm:
             data = self.ln_inter(data)
-        out, _ = self.inter_attention(npx.reshape(self.attn_inter_q(data),
-                                                  (-2, -2, self._num_heads, -1)),
-                                      npx.reshape(self.attn_inter_k(mem),
-                                                  (-2, -2, self._num_heads, -1)),
-                                      npx.reshape(self.attn_inter_v(mem),
-                                                  (-2, -2, self._num_heads, -1)),
-                                      mem_attn_mask)
+        query = npx.reshape(self.attn_inter_q(data), (-2, -2, self._num_heads, -1))
+        if prev_inter_kv is None:
+            key = npx.reshape(self.attn_inter_k(mem), (-2, -2, self._num_heads, -1))
+            value = npx.reshape(self.attn_inter_v(mem), (-2, -2, self._num_heads, -1))
+            prev_inter_kv = {'prev_key': key, 'prev_value': value}
+            out, _ = self.inter_attention(query,
+                                          key,
+                                          value,
+                                          mem_attn_mask)
+        else:
+            out, _ = self.inter_attention(query,
+                                          prev_inter_kv['prev_key'],
+                                          prev_inter_kv['prev_value'],
+                                          mem_attn_mask)
         out = self.proj_inter(out)
         out = self.dropout_layer(out)
         out = out + data
@@ -731,7 +740,7 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
         # 3. Encode the output via an FFN layer
         out = self.ffn(out)
         out = npx.reshape(out, (-5, -1))
-        return out, (new_key, new_value)
+        return out, (new_key, new_value), prev_inter_kv
 
 
 @use_np
@@ -881,7 +890,7 @@ def init_states(self, batch_size, ctx, dtype='float32'):
                                       dtype=dtype))
         return states
 
-    def incremental_decode(self, data, states, mem, mem_valid_length):
+    def incremental_decode(self, data, states, mem, mem_valid_length, prev_inter_kvs=None):
         """Incrementally generate the output given the decoder input.
 
         Parameters
@@ -914,6 +923,9 @@ def incremental_decode(self, data, states, mem, mem_valid_length):
             Valid length of the memory
             Shape (batch_size,)
 
+        prev_inter_kvs
+            The previously cached inter_attention keys, values
+
         Returns
         -------
         out
@@ -946,17 +958,23 @@ def incremental_decode(self, data, states, mem, mem_valid_length):
         # TODO(sxjscience) Try with boolean masking
         mem_attn_mask = mem_attn_mask.astype(self._dtype)
         new_states = []
+        if prev_inter_kvs is None:
+            prev_inter_kvs = [None for _ in range(self.num_layers)]
+        new_inter_kvs = []
         for i in range(self.num_layers):
             if self.recurrent:
                 layer = self.layers[0]
             else:
                 layer = self.layers[i]
-            out, new_state = layer.incremental_decode(out, states[i],
-                                                      mem, mem_valid_length, mem_attn_mask)
+            out, new_state, prev_inter_kv = layer.incremental_decode(out, states[i],
+                                                                     mem, mem_valid_length,
+                                                                     mem_attn_mask,
+                                                                     prev_inter_kvs[i])
             new_states.append(new_state)
+            new_inter_kvs.append(prev_inter_kv)
         if self._pre_norm:
             out = self.ln_final(out)
-        return out, new_states
+        return out, new_states, new_inter_kvs
 
 
 @use_np
@@ -1364,6 +1382,7 @@ def __init__(self, model):
         """
         super().__init__()
         self.model = model
+        self.cached_kvs = None
 
     def initialize(self, **kwargs):
         # Manually disable the initialize
@@ -1464,9 +1483,9 @@ def forward(self, step_data, states):
         step_data = step_data * _np.sqrt(self.model.dec_units)
         if self.model.pos_embed_type is not None:
            step_data = step_data + self.model.tgt_pos_embed_layer(position)
-        out, new_states =\
+        out, new_states, self.cached_kvs =\
             self.model.decoder.incremental_decode(step_data, dec_states,
-                                                  mem_data, mem_valid_length)
+                                                  mem_data, mem_valid_length, self.cached_kvs)
         out = self.model.tgt_final_layer(out)
         return out, (mem_data, mem_valid_length, position + 1, new_states)
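Note on the change, with an illustrative sketch (not part of the patch): the decoder's cross-attention ("inter_attention") keys and values depend only on the encoder memory, so they are identical on every decoding step. The patch therefore projects the memory once, on the first call to incremental_decode, stores the projections per layer (prev_inter_kv / prev_inter_kvs), and threads them through later steps via the inference wrapper's self.cached_kvs. The NumPy sketch below shows the same pattern with hypothetical names (cross_attention_step, w_q, caches, ...); it is a minimal single-head approximation, not GluonNLP code.

# Minimal NumPy sketch of the cross-attention KV caching pattern in this patch.
# All identifiers here are illustrative assumptions, not GluonNLP API.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention_step(step_data, mem, w_q, w_k, w_v, cached_kv=None):
    """One incremental step of a single-head cross-attention layer.

    cached_kv mirrors prev_inter_kv in the patch: None on the first step,
    a dict with 'prev_key'/'prev_value' on every later step.
    """
    query = step_data @ w_q                        # (batch, 1, d)
    if cached_kv is None:
        # Project the encoder memory only once per sequence.
        cached_kv = {'prev_key': mem @ w_k,        # (batch, mem_len, d)
                     'prev_value': mem @ w_v}
    scores = query @ cached_kv['prev_key'].transpose(0, 2, 1)
    out = softmax(scores / np.sqrt(query.shape[-1])) @ cached_kv['prev_value']
    return out, cached_kv

# Toy usage with a stack of layers: the per-layer caches (analogous to
# prev_inter_kvs / self.cached_kvs) are built on step 0 and reused on step 1.
rng = np.random.default_rng(0)
batch, mem_len, d, num_layers = 2, 5, 8, 3
mem = rng.normal(size=(batch, mem_len, d))
weights = [tuple(rng.normal(size=(d, d)) for _ in range(3)) for _ in range(num_layers)]
caches = [None] * num_layers
for step in range(2):
    out = rng.normal(size=(batch, 1, d))           # stand-in for the step embedding
    for i, (w_q, w_k, w_v) in enumerate(weights):
        out, caches[i] = cross_attention_step(out, mem, w_q, w_k, w_v, caches[i])
print(out.shape)  # (2, 1, 8)

As in the patch, the cache is returned from the step function and stored by the caller (here the `caches` list, in the patch the inference wrapper's self.cached_kvs), which keeps the layers themselves stateless between sequences.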