Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
fix bias concat (#1296)
Browse files Browse the repository at this point in the history
Co-authored-by: Lin <[email protected]>
  • Loading branch information
eric-haibin-lin and Lin authored Aug 12, 2020
1 parent 528283d commit d75185e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/gluonnlp/model/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def hybrid_forward(self, F, qkv, valid_len, query_bias, key_bias, value_bias,
value_weight = value_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True)
in_weight = F.concat(query_weight, key_weight, value_weight, dim=-2)
in_weight = in_weight.reshape(shape=(-1, 0), reverse=True)
in_bias = F.concat(query_bias, key_bias, value_bias, dim=0)
# concat bias
query_bias = query_bias.reshape(shape=(self._num_heads, -1), reverse=True)
key_bias = key_bias.reshape(shape=(self._num_heads, -1), reverse=True)
value_bias = value_bias.reshape(shape=(self._num_heads, -1), reverse=True)
in_bias = F.stack(query_bias, key_bias, value_bias, axis=1).reshape(-1)

# qkv_proj shape = (seq_length, batch_size, num_heads * head_dim * 3)
qkv_proj = F.FullyConnected(data=qkv, weight=in_weight, bias=in_bias,
Expand Down

0 comments on commit d75185e

Please sign in to comment.