LongNetTransformer Error #25
Comments
Having the same issue on Python 3.11. I believe it has to do with the skip connections implemented in the Transformer class's forward() function. The first ParallelTransformerBlock has a default dilation rate of 2, so it produces half as many output tokens (256 instead of the 512 a standard transformer would return in this example). You can verify this with the other example that uses DilatedAttention directly. It also means the skip connection cannot be added, because the sequence dimension of the original input no longer matches that of the dilated output (see the sketch after the repro below). I'm not sure what the original authors intended, though. Steps to reproduce:
import torch
from long_net.model import LongNetTransformer

longnet = LongNetTransformer(
    num_tokens=20000,
    dim=512,
    depth=6,
    dim_head=64,
    heads=8,
    ff_mult=4,
)

tokens = torch.randint(0, 20000, (1, 512))
logits = longnet(tokens)
print(logits)

Produces a RuntimeError about a tensor-size mismatch along the sequence dimension.
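To see the mismatch in isolation, here is a minimal standalone PyTorch sketch. This is not the LongNet code; dilated_downsample is a hypothetical stand-in for the token selection that dilated attention performs with dilation rate 2:

import torch

def dilated_downsample(x, dilation=2):
    # Hypothetical stand-in for dilated attention's token selection:
    # keeping every `dilation`-th token halves the sequence length for dilation=2.
    return x[:, ::dilation, :]

x = torch.randn(1, 512, 512)   # (batch, seq_len, dim), matching the repro above
out = dilated_downsample(x)    # shape (1, 256, 512)

# The residual add in the Transformer's forward() then has nothing to line up with:
try:
    _ = out + x                # 256 vs 512 along the sequence dimension
except RuntimeError as e:
    print(e)  # e.g. "The size of tensor a (256) must match the size of tensor b (512) ..."

Any fix presumably has to either gather the dilated outputs back to the full sequence length before the residual add, or apply the skip connection only to the selected token positions.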
I ran the example program and got the following error.
It looks like there's something wrong internally?
Upvote & Fund
The text was updated successfully, but these errors were encountered: