A simple, clean, readable, and shape-annotated implementation of Attention is All You Need in PyTorch. A sample ONNX file can be found in `assets/transformer.onnx` for visualization purposes.

It was tested on synthetic data; try to use the attention plots to figure out the transformation used to create the data!
- Positional embeddings are not included, similar to `nn.Transformer`, but you can find an implementation in `usage.ipynb` (a minimal sketch is also given below).
- The parallel `MultiHeadAttention` significantly outperforms the for-loop implementation, as expected (see the illustration after this list).
- Assumes `batch_first=True` input by default; this cannot be changed.
- Uses `einsum` for the attention computation rather than `bmm`, for readability; this might impact performance (a sketch follows below).
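Since positional embeddings are left out of the model, here is a minimal sketch of the standard sinusoidal encoding from the paper, assuming `batch_first=True` inputs of shape `(batch, seq_len, d_model)`. It is illustrative only and not necessarily identical to the version in `usage.ipynb`.

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Sinusoidal encoding from the paper, shape (seq_len, d_model)."""
    position = torch.arange(seq_len).unsqueeze(1)  # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
    return pe

# Usage: add to token embeddings of shape (batch, seq_len, d_model)
x = torch.randn(2, 10, 512)
x = x + sinusoidal_positional_encoding(x.size(1), x.size(2)).unsqueeze(0)
```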
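As a rough illustration of why the parallel multi-head attention wins over the per-head loop, the sketch below contrasts a sequential per-head projection with a single fused projection that is reshaped into heads. Names such as `num_heads` and `head_dim` are assumptions for the example, not the repo's exact API.

```python
import torch
import torch.nn as nn

batch, seq_len, d_model, num_heads = 2, 10, 512, 8
head_dim = d_model // num_heads
x = torch.randn(batch, seq_len, d_model)

# For-loop version: one small projection per head, applied sequentially.
per_head_q = nn.ModuleList([nn.Linear(d_model, head_dim) for _ in range(num_heads)])
q_loop = torch.stack([proj(x) for proj in per_head_q], dim=1)  # (batch, heads, seq, head_dim)

# Parallel version: a single large projection, then a reshape into heads.
q_proj = nn.Linear(d_model, d_model)
q_par = q_proj(x).view(batch, seq_len, num_heads, head_dim).transpose(1, 2)  # (batch, heads, seq, head_dim)
```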
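For the `einsum` point, here is a hedged sketch of scaled dot-product attention written with `einsum` instead of `bmm`. The tensor layout `(batch, heads, seq_len, head_dim)` and the function name are illustrative assumptions, not necessarily the exact code in this repo.

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention_einsum(q, k, v):
    """q, k, v: (batch, heads, seq_len, head_dim), batch-first layout."""
    d_k = q.size(-1)
    # Attention scores (batch, heads, q_len, k_len) via einsum rather than bmm/matmul
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / d_k ** 0.5
    attn = F.softmax(scores, dim=-1)
    # Weighted sum over values: (batch, heads, q_len, head_dim)
    return torch.einsum("bhqk,bhkd->bhqd", attn, v), attn

# A bmm-based version would need to flatten batch and heads first, e.g.
# torch.bmm(q.flatten(0, 1), k.flatten(0, 1).transpose(-2, -1)), which is
# less readable but may be faster on some backends.
```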