
Synthetic copy task data question #127

Open
cmmcirvin opened this issue Sep 22, 2024 · 8 comments

@cmmcirvin

I'm slightly confused by the line data[:, 0] = 1 in the data generator for the simple copy task.

It seems to manually set the first column of the data to 1. I'm not sure what the reason for this is; it sometimes causes the first value in the output array to come out wrong, and removing it doesn't seem to hurt performance or break anything. What is this line intended to do?

Thanks for the help!

@hhxxttxs-tang

It might suggest that the SOS token is defined as 1.

@cmmcirvin
Author

I think that the SOS token is 0, though? If we look at the final output example, the start_symbol argument passed into the greedy_decode function is 0.

    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))

Looking at that greedy_decode function, start_symbol seems to be used as the first output label, where I'd expect the SOS token to be.
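
For reference, the decoding loop looks roughly like this (abridged from the notebook's greedy_decode, so details may differ slightly; model and subsequent_mask are as defined there):

    def greedy_decode(model, src, src_mask, max_len, start_symbol):
        memory = model.encode(src, src_mask)
        # The decoder input is seeded with start_symbol, so whatever value is
        # passed here effectively plays the role of the SOS token at inference.
        ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
        for i in range(max_len - 1):
            out = model.decode(
                memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data)
            )
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            # Append the predicted token and feed the growing sequence back in.
            ys = torch.cat(
                [ys, torch.zeros(1, 1).type_as(src.data).fill_(next_word.item())],
                dim=1,
            )
        return ys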

@hhxxttxs-tang

hhxxttxs-tang commented Sep 23, 2024

IMO, "0" as start_symbol is wrong. It needs to match with what is used for training, otherwise you won't get correct output for inference

"0" is used for padding instead - see the last line of code in fun: data_gen()

    yield Batch(src.to(device), tgt.to(device), 0) 
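
For reference, that third argument is the padding index; the Batch class uses it to build the attention masks. Roughly (an abridged sketch, so exact details may differ):

    class Batch:
        "Object for holding a batch of data with mask during training."

        def __init__(self, src, tgt=None, pad=0):
            self.src = src
            # Positions equal to the pad index are masked out of attention,
            # so passing 0 as the third argument makes 0 the padding token.
            self.src_mask = (src != pad).unsqueeze(-2)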

@cmmcirvin
Author

I think just changing the above line from data[:, 0] = 1 to data[:, 0] = 0 should be fine, correct? Training seems more stable when I do that, at least - if we use 0 as the start symbol everywhere by modifying that line, things should be consistent.

I don't think 1 is a good SOS token, as our data for the copy task ranges from 1 to V - 1 (line below; torch.randint's upper bound is exclusive), so I don't think it makes sense for the SOS token to be the same as a valid token we want to copy.

    data = torch.randint(1, V, size=(batch_size, 10))
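
A quick check of the value range (a minimal sketch; V = 11 is an assumption matching the notebook's toy example):

    import torch

    V = 11  # assumed vocabulary size, for illustration only
    data = torch.randint(1, V, size=(2, 10))
    print(data.min().item(), data.max().item())  # always within [1, V - 1]
    data[:, 0] = 1  # the forced "SOS" value collides with a legitimate data symbol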

@hhxxttxs-tang

hhxxttxs-tang commented Sep 23, 2024

"0" is good if no padding is involved, otherwise, i think they need to be defined separately.
not sure what's the best practice for a good SOS token.

@cmmcirvin
Author

Ah, I see how 0 is being used for padding now, thanks. I was misunderstanding that part before.

I still don't think the current implementation handles the SOS token correctly, because 1 is a valid element of the dataset, but I see now why 0 would also be a bad idea. Thanks!

@kuraga

kuraga commented Sep 24, 2024

#116 ?

@PangLuo

PangLuo commented Dec 9, 2024

I think the authors might not have noticed this issue because it is just a toy example demonstrating inference; I likely would have made the same mistake if I were the author. The setup in this example may be self-contradictory. Looking at the definition of data_gen below, I guess the authors intended to (arbitrarily) set 1 as the start symbol of every sequence and 0 as the padding in Batch(src, tgt, 0).

def data_gen(V, batch_size, nbatches):
    "Generate random data for a src-tgt copy task."
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch_size, 10))
        data[:, 0] = 1  # force the first token of every sequence to 1 (the intended start symbol)
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        yield Batch(src, tgt, 0)  # 0 is the padding index

To maintain this convention, the code in the function example_simple_model should ideally be changed from

    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))

into

    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 5]]) # just an example
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=1))

That is, src starts with 1 and start_symbol is set to 1 for greedy_decode. The reason is that the original src starts with a padding token, which is unusual, and the src_mask consistent with pad = 0 would be torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1]) instead of torch.ones(1, 1, max_len). I think the output of greedy_decode will be similar to src even without this change, because the model has already been trained, but it's best to make the fix to avoid confusion.
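
To illustrate the mask point, here is a minimal check using the same (src != pad) construction the Batch class uses (assuming pad = 0, as in data_gen):

    import torch

    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    pad = 0
    src_mask = (src != pad).unsqueeze(-2)  # mirrors Batch's mask construction
    print(src_mask.int())
    # tensor([[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1]]], dtype=torch.int32)
    # -> the leading padding token is masked out, unlike torch.ones(1, 1, max_len)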
