
Cannot reproduce the results in the paper. #2

Open · yfzhang114 opened this issue Nov 17, 2022 · 2 comments

yfzhang114 commented Nov 17, 2022

Hi!
This is very interesting work, and I tried to reproduce the results in your paper. However, when I ran the code from the README with the default config, namely images_all_exemplars.py, I got some unexpected results.

When I run

```
$ python -m emergent_in_context_learning.experiment.experiment --config $PATH_TO_CONFIG --logtostderr --config.one_off_evaluate --config.restore_path $CKPT_DIR --jaxline_mode eval_fewshot_holdout
```

I get:

```
I1117 09:06:41.539124 140185420793280 data_generators.py:241] Zipf exponent: 0
I1117 09:06:41.539163 140185420793280 data_generators.py:242] Use Zipf for common/rare: False
I1117 09:06:41.539418 140185420793280 data_generators.py:243] Noise scale: 0
I1117 09:07:17.796816 140185420793280 data_generators.py:241] Zipf exponent: 0
I1117 09:07:17.796856 140185420793280 data_generators.py:242] Use Zipf for common/rare: False
I1117 09:07:17.796890 140185420793280 data_generators.py:243] Noise scale: 0
I1117 09:07:17.933558 140185420793280 utils.py:590] Returned checkpoint latest with id 0.
I1117 09:08:14.875151 140185420793280 experiment.py:552] [Step 500000] eval_loss=6.79, eval_accuracy=0.27
```

And for in-weights learning, when I run

```
$ python -m emergent_in_context_learning.experiment.experiment --config $PATH_TO_CONFIG --logtostderr --config.one_off_evaluate --config.restore_path $CKPT_DIR --jaxline_mode eval_no_support_zipfian
```

I get:

```
I1117 09:19:25.652572 140016307090880 experiment.py:552] [Step 500000] eval_loss=1.37, eval_accuracy=0.63
I1117 09:19:25.652797 140016307090880 experiment.py:555] accuracy_closed: 0
I1117 09:19:25.655787 140016307090880 experiment.py:555] accuracy_interim: 0
I1117 09:19:25.658926 140016307090880 experiment.py:555] accuracy_query: 0
I1117 09:19:25.661794 140016307090880 experiment.py:555] from_common: 0
I1117 09:19:25.665280 140016307090880 experiment.py:555] from_fewshot: 0
I1117 09:19:25.667965 140016307090880 experiment.py:555] from_rare: 0
I1117 09:19:25.670737 140016307090880 experiment.py:555] from_support: 0
I1117 09:19:25.673430 140016307090880 experiment.py:555] from_support_common: 0
I1117 09:19:25.676214 140016307090880 experiment.py:555] from_support_fewshot: 0
I1117 09:19:25.679983 140016307090880 experiment.py:555] from_support_rare: 0
I1117 09:19:25.683503 140016307090880 experiment.py:555] loss: 1
I1117 09:19:25.686269 140016307090880 experiment.py:555] loss_interim: 0
I1117 09:19:25.689311 140016307090880 experiment.py:555] loss_query: 1
I1117 09:19:25.693018 140016307090880 experiment.py:555] last_label: 927
I1117 09:19:25.696013 140016307090880 experiment.py:555] last_prediction: 995
I1117 09:19:25.715014 140016307090880 train.py:200] Evaluated specific checkpoint, exiting.
```

So the in-weights learning accuracy is 0.63, which is higher than the in-context learning accuracy of 0.27.

scychan (Collaborator) commented Dec 5, 2022

Hi Yifan,

Thanks for your question! The eval metrics reported in the paper are accuracy_query for in-weights learning (evaluated only on the query prediction) and accuracy_closed for in-context learning (i.e. accuracy computed only over the labels actually observed in the context; see Section 2.3 in the paper). What were those numbers for eval_fewshot_holdout?
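For concreteness, here is a minimal sketch of how such a closed-class accuracy could be computed. This is not the repo's implementation; the function name and array shapes are illustrative. The idea is simply to restrict the argmax to the labels that appear in each sequence's context rather than taking it over the full vocabulary:

```python
import numpy as np

def closed_class_accuracy(logits, query_labels, context_labels):
    """Accuracy with the argmax restricted to labels seen in context.

    logits:         [batch, n_classes] model outputs for the query.
    query_labels:   [batch] ground-truth label of each query.
    context_labels: [batch, context_len] labels shown in each context.
    """
    batch, n_classes = logits.shape
    # Additive mask: 0 for classes observed in the context, -inf otherwise.
    mask = np.full((batch, n_classes), -np.inf)
    rows = np.repeat(np.arange(batch), context_labels.shape[1])
    mask[rows, context_labels.ravel()] = 0.0
    preds = np.argmax(logits + mask, axis=-1)
    return float(np.mean(preds == query_labels))
```

Note that chance level under this metric is one over the number of distinct labels in the context, not 1/n_classes.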

Unfortunately, the transformer model that is available here is not exactly the one we ran internally, because the internal version has some dependencies that I wasn't able to open-source. This external implementation has not been fully tested, and may unfortunately account for some of the discrepancies. My recommendation is to incorporate the data generators into your favorite transformer train/eval framework. Hope that helps!

Best,
Stephanie

chanb commented Jun 13, 2024

@yfzhang114 not sure if you've figured this out. I was trying to reproduce this work on my end here. In short, I think what happens is that the batch-norm statistics used in the ResNet are not synchronized across devices. If we do synchronize them, the results are very different.
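To illustrate what I mean by synchronizing, here is a minimal sketch (not the repo's code; shapes and names are illustrative) assuming a Haiku BatchNorm like the one inside the ResNet embedder. Passing cross_replica_axis makes the layer average its batch statistics across devices; its value must match the axis_name bound by jax.pmap:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    # cross_replica_axis='i' averages batch statistics over every
    # device participating in the pmap that binds axis_name='i'.
    bn = hk.BatchNorm(create_scale=True, create_offset=True,
                      decay_rate=0.9, cross_replica_axis='i')
    return bn(x, is_training=is_training)

net = hk.transform_with_state(forward)
n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 8, 32))  # one data shard per device

# Both init and apply must run under a pmap binding the same axis name.
init_fn = jax.pmap(lambda k, xs: net.init(k, xs, True), axis_name='i')
keys = jnp.stack([jax.random.PRNGKey(0)] * n_dev)
params, state = init_fn(keys, x)

apply_fn = jax.pmap(lambda p, s, xs: net.apply(p, s, None, xs, True),
                    axis_name='i')
out, state = apply_fn(params, state, x)
```

With cross_replica_axis=None (the default), each device keeps its own statistics, which is the "async" behavior described above.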

Below is an experiment where I set P(bursty) = 1, where "async" denotes the unsynchronized batch-norm statistics and "sync" the synchronized ones.

[image: training curves comparing the sync and async variants]
