[P1] Error(s) in loading state_dict for Linear #115

Open
Hamana0509 opened this issue Jun 25, 2024 · 2 comments
Assignees: frankaging
Labels: question (Further information is requested)

@Hamana0509

I tried to load the ReFT modules and combine them with the base model, but I got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 28
     20 # Load model
     21 model = AutoModelForCausalLM.from_pretrained(
     22     model_name,
     23     quantization_config=quant_config,
     24     device_map="auto",
     25     trust_remote_code=True,
     26 )
---> 28 reft_model = pyreft.ReftModel.load(modules_path, model, from_huggingface_hub=True)

File /opt/conda/lib/python3.10/site-packages/pyreft/reft_model.py:26, in ReftModel.load(*args, **kwargs)
     24 @staticmethod
     25 def load(*args, **kwargs):
---> 26     model = pv.IntervenableModel.load(*args, **kwargs)
     27     return ReftModel._convert_to_reft_model(model)

File /opt/conda/lib/python3.10/site-packages/pyvene/models/intervenable_base.py:569, in IntervenableModel.load(load_directory, model, local_directory, from_huggingface_hub)
    567     elif isinstance(intervention, TrainableIntervention):
    568         saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
--> 569         intervention.load_state_dict(saved_state_dict)
    571 return intervenable

File /opt/conda/lib/python3.10/site-packages/pyreft/interventions.py:68, in LoreftIntervention.load_state_dict(self, state_dict, *args, **kwargs)
     64 def load_state_dict(self, state_dict, *args, **kwargs):
     65     """
     66     Overwrite for data-efficiency.
     67     """
---> 68     self.learned_source.load_state_dict(state_dict, strict=False)
     69     overload_w = state_dict["rotate_layer"]
     70     overload_w_width = overload_w.shape[-1]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:2153, in Module.load_state_dict(self, state_dict, strict, assign)
   2148         error_msgs.insert(
   2149             0, 'Missing key(s) in state_dict: {}. '.format(
   2150                 ', '.join(f'"{k}"' for k in missing_keys)))
   2152 if len(error_msgs) > 0:
-> 2153     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2154                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2155 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Linear:
	size mismatch for weight: copying a param with shape torch.Size([4, 4096]) from checkpoint, the shape in current model is torch.Size([2, 4096]).
	size mismatch for bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([2]).

My code for saving the ReFT modules:

reft_model.save(
    save_directory="./reft_to_share",
    save_to_hf_hub=True,
    hf_repo_name="Hamana0509/ReFT_Orpo_Llama3_8B_Instruct"
)

My code for loading and combining the ReFT modules:

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
modules_path = "Hamana0509/ReFT_Orpo_Llama3_8B_Instruct"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=2048,
    padding_side="right",
    trust_remote_code=True,
    use_fast=False,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Define quantization config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)

reft_model = pyreft.ReftModel.load(modules_path, model, from_huggingface_hub=True)
@frankaging frankaging changed the title Error(s) in loading state_dict for Linear [P1] Error(s) in loading state_dict for Linear Jun 25, 2024
@frankaging frankaging self-assigned this Jun 25, 2024
@frankaging frankaging added the question Further information is requested label Jun 25, 2024
@frankaging
Collaborator

@Hamana0509 Thanks for raising the issue. I probably need more info to debug here.

I checked your published HF model json file: https://huggingface.co/Hamana0509/ReFT_Orpo_Llama3_8B_Instruct/blob/main/config.json

It seems the intervention config specifies a low-rank dimension of 2, while the saved weights have a low-rank dimension of 4.

Did you overwrite those saved weights somehow? If not, could you try randomly initializing a model, saving it, and reloading it? You can also manually check the dimensions of the saved weights with the torch.load() API, for example as sketched below.
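A minimal sketch of such a check, assuming the repo layout pyvene uses for saved interventions (a config.json plus one .bin file per intervention); the `intkey_*.bin` filename below is a hypothetical example and should be replaced with one of the files actually listed in the repo:

```python
import json
import torch
from huggingface_hub import hf_hub_download, list_repo_files

repo_id = "Hamana0509/ReFT_Orpo_Llama3_8B_Instruct"

# List the repo files to find the intervention binaries and the config.
print(list_repo_files(repo_id))

# Inspect the low-rank dimension recorded in the saved config.
config_path = hf_hub_download(repo_id, "config.json")
with open(config_path) as f:
    print(json.dumps(json.load(f), indent=2))

# Inspect the shapes of the saved intervention parameters.
# NOTE: this filename is a hypothetical example; use one of the .bin files
# printed by list_repo_files() above.
weights_path = hf_hub_download(
    repo_id, "intkey_layer.15.comp.block_output.unit.pos.nunit.1#0.bin"
)
state_dict = torch.load(weights_path, map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```

If the printed weight shapes disagree with the low-rank dimension in config.json, the mismatch above would be reproduced at load time.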

@Hamana0509
Author

@frankaging Thank you for answering my question. Here is my source code:
https://colab.research.google.com/drive/1ZKzqVF2d4L1YJNN5_jqnvr4uc1N3EbGj?usp=sharing

I use ORPOReftTrainer, which is re-implemented from ORPOTrainer in the trl package; I don't know whether my trainer overrides any model parameters.
