Synthetic copy task data question #127
It might suggest that the SOS token is defined as "1".
I think that the SOS token is 0, though? If we look at the final output example, the ... Looking at that ...
IMO, "0" as start_symbol is wrong. It needs to match what was used for training, otherwise you won't get correct output at inference. "0" is used for padding instead; see the last line of code in data_gen().
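To illustrate the point about data_gen(), here is a hypothetical pure-Python sketch of its logic (the real notebook uses torch tensors and yields Batch(src, tgt, 0)): tokens are drawn from 1..V-1, the first position of every sequence is forced to 1, so 0 never appears in the data and is free to serve as the padding symbol.

```python
import random

def data_gen_sketch(V, seq_len, batch_size):
    """Sketch of the copy-task generator: 1 is the de facto start
    symbol, 0 is reserved for padding and never appears in the data."""
    batch = []
    for _ in range(batch_size):
        # random.randint is inclusive on both ends: tokens in 1..V-1
        seq = [random.randint(1, V - 1) for _ in range(seq_len)]
        seq[0] = 1  # force the start symbol, as data_gen does
        batch.append(seq)
    return batch

batch = data_gen_sketch(V=11, seq_len=10, batch_size=4)
assert all(seq[0] == 1 for seq in batch)   # every sequence begins with 1
assert all(0 not in seq for seq in batch)  # 0 only ever means padding
```

This is why a start_symbol of 0 at inference time conflicts with training: the model has never seen 0 anywhere except as padding.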
I think just changing the above line from ...

I don't think ...
"0" is fine if no padding is involved; otherwise, I think the start symbol and padding symbol need to be defined separately.
Ah, I see. I still don't think the current implementation handles the SOS token correctly, because ...
#116?
I think the authors might not have noticed this issue because it is just a toy example demonstrating inference. I likely would have made the same mistake if I were the author. The data used in this example may be self-contradictory. Looking at the definition of data_gen below, I guess the authors intended to (arbitrarily) set 1 as the starting symbol for every sequence and 0 as the padding in Batch(src, tgt, 0).
To maintain this convention, the code in the function example_simple_model should ideally be changed from ... into ...
That is, src starts with 1 and start_symbol is set to 1 for greedy_decode. The reason is that the original src starts with a padding token, which is unusual, and src_mask would be torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) instead of torch.ones(1, 1, max_len). I think the output of greedy_decode would be similar to src even without this change, since the model has already been trained, but it's best to make the fix to avoid confusion.
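The mask point above can be sketched in plain Python (hypothetical values; the notebook builds the mask with torch and the pad index from Batch): if the padding mask is built as (src != pad), a source that begins with 0 gets its first position masked out, which is exactly the src_mask discrepancy described.

```python
pad = 0
src_before = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]   # example input starting with the pad token
src_after  = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # suggested fix: start with 1

# A padding mask of the usual (src != pad) form, written out by hand:
mask_before = [int(tok != pad) for tok in src_before]
mask_after  = [int(tok != pad) for tok in src_after]

assert mask_before == [0, 1, 1, 1, 1, 1, 1, 1, 1, 1]  # position 0 masked out
assert mask_after == [1] * 10                          # matches torch.ones(1, 1, max_len)
```

So starting src with 1 keeps the all-ones mask used in the example consistent with how a real padding mask would be computed.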
I'm slightly confused by the line below in the data generator for the simple copy task.
annotated-transformer/the_annotated_transformer.py
Line 1279 in debc9fd
It seems to manually set the first token of every sequence to 1. I'm not sure what the reason for this is, and it sometimes seems to cause the first value in the output array to be wrong, while removing it doesn't seem to hurt performance or break anything. What is this line intended to do?
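One plausible answer, sketched below in pure Python (a hypothetical, simplified stand-in for the notebook's greedy_decode): the decoder loop is seeded with start_symbol, so forcing every training sequence to begin with the same token 1 gives the model a consistent first input at decode time. The "model" here is a trivial lambda that echoes a stored source, which is all the copy task requires for illustration.

```python
def greedy_decode_sketch(step_fn, max_len, start_symbol):
    """Greedily extend a sequence one token at a time, starting from
    start_symbol; step_fn picks the next token given the prefix so far."""
    ys = [start_symbol]
    for _ in range(max_len - 1):
        ys.append(step_fn(ys))
    return ys

# A trivial "model" for the copy task: the next token is just the
# source token at the next position.
src = [1, 3, 5, 7]
out = greedy_decode_sketch(lambda ys: src[len(ys)], max_len=4, start_symbol=src[0])
assert out == src  # decoding reproduces the source only because the seeds match
```

If the training data did not share a fixed first token, there would be no single start_symbol that matches what the model saw during training, which may be why removing the line changes the first output value.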
Thanks for the help!