
Error when launching reward model training with the weights produced in the SFT stage #11

Open
javismiles opened this issue May 2, 2024 · 2 comments

Comments

@javismiles

I followed your instructions for the SFT training and got the trained .pt weights. I then ran:

python train_rm.py -b 2 -n experiment_name -p "./runs/sft_javSFT2_202405021344/sft_javSFT2_202405021344_step4000.pt"

and got this error at line 87 of train_rm.py:

File "/xxxxxxxxxxxxxxxxxxxx/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT:
Missing key(s) in state_dict: "transformer.decoder_blocks.0.mmsa.qkv_projection.lora_A", "transformer.decoder_blocks.0.mmsa.qkv_projection.lora_B", "transformer.decoder_blocks.0.mmsa.output_projection.lora_A", "transformer.decoder_blocks.0.mmsa.output_projection.lora_B", "transformer.decoder_blocks.0.ffn.fc1.lora_A", "transformer.decoder_blocks.0.ffn.fc1.lora_B", "transformer.decoder_blocks.0.ffn.fc2.lora_A", "transformer.decoder_blocks.0.ffn.fc2.lora_B", [... the same eight lora_A/lora_B keys repeat for transformer.decoder_blocks.1 through transformer.decoder_blocks.23 ...], "lm_head.lora_A", "lm_head.lora_B".

@satel33

satel33 commented May 31, 2024

@javismiles Please check that the model class and the state dict match. The checkpoint is loaded with GPT.from_checkpoint(), which is a different class from GPTRewardModel. When I updated it to use GPTRewardModel, this error was fixed.

@ethanyanjiali, would you please look into this as well? Is this intended?
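
Roughly, the change looks like this (a sketch only; GPTRewardModel's constructor arguments and the checkpoint layout are assumptions here, so match them to the repo's actual code):

import torch

# Construct the reward model first, so that its parameters (including any
# lora_A/lora_B weights and the reward head) are registered, then load the
# SFT weights into it instead of going through GPT.from_checkpoint().
model = GPTRewardModel(cfg)                    # `cfg` as already built in train_rm.py
state = torch.load(checkpoint_path, map_location="cpu")
state = state.get("model_state_dict", state)   # unwrap if the script saved a wrapper dict
# strict=False lets freshly initialized parameters (e.g. the reward head)
# keep their init values when the SFT checkpoint does not contain them.
model.load_state_dict(state, strict=False)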

@abhishek-sharma-iisc

@satel33 Can you please explain more clearly how this error needs to be tackled? I'm also getting the same error.
