[P1] Refactor ReftTrainer to save artifacts with the config #109

Open
BryanWBear opened this issue Jun 13, 2024 · 1 comment
Labels
engineering enhancement New feature or request

Comments

@BryanWBear

The issue is that ReftTrainer.save_model does not save the ReftConfig, only the intervention.
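To make the gap concrete, one way to picture a checkpoint saved "with the config" is a directory that holds a serialized config next to the intervention weights. The sketch below is a minimal stdlib illustration under stated assumptions: `save_checkpoint`, `config.json`, and `intervention.json` are hypothetical names, not pyreft's actual checkpoint layout, and JSON lists stand in for the real tensors.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical file names for illustration; not pyreft's actual checkpoint layout.
CONFIG_FILE = "config.json"
WEIGHTS_FILE = "intervention.json"


def save_checkpoint(save_dir, config, weights):
    """Write the intervention config next to its weights in save_dir.

    config:  plain dict describing the intervention (layer, component, rank, ...).
    weights: stand-in for the intervention tensors (name -> nested lists).
    Returns the checkpoint directory as a Path.
    """
    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    # The config file is the piece a weights-only save leaves out.
    (out / CONFIG_FILE).write_text(json.dumps(config, indent=2, sort_keys=True))
    (out / WEIGHTS_FILE).write_text(json.dumps(weights))
    return out


if __name__ == "__main__":
    cfg = {"layer": 15, "component": "block_output", "low_rank_dimension": 4}
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = save_checkpoint(Path(tmp) / "checkpoint-78", cfg, {"rotate": [[0.0] * 4]})
        print(sorted(p.name for p in ckpt.iterdir()))
```

With both files in the same directory, a loader can rebuild the config from the checkpoint instead of relying on the caller to remember the training-time settings.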

As a workaround, we can load the model from the checkpoint using the following code (by reinstantiating the config manually):

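```python
import torch
import pyreft
import pyvene as pv
from transformers import AutoModelForCausalLM

# model_name: the same base model that was used for training.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")

# Re-create the ReftConfig by hand; it must match the config used during training.
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=4)})
reft_model = pv.IntervenableModel(reft_config, model)

# Load the intervention weights written by ReftTrainer.save_model.
reft_model.load_intervention('./tmp/checkpoint-78/intervenable_model')

# Move the loaded interventions to the same device as the base model.
device = 'cuda'
for k, v in reft_model.interventions.items():
    v[0].to(device)
```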

Please let me know if I am missing something!

Thanks,
Bryan

frankaging changed the title from "Can't load ReftModel from checkpoint trained by ReftTrainer" to "[P1] Can't load ReftModel from checkpoint trained by ReftTrainer" on Jun 13, 2024
@frankaging
Collaborator

@BryanWBear Yes! I am turning this ticket into a feature request, which I can work on later. Thanks for bringing this up.

For now, to save your ReFT model, you can use our own API, reft_model.save(<your_dir>), instead of the trainer's API. I think this API saves the config as well as the other artifacts.
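On the load side, the point of a checkpoint that carries its own config is that the loader, not the caller, supplies the layer, component, and rank, instead of the manual re-instantiation in the workaround. The following is a minimal stdlib sketch under assumptions: `load_checkpoint`, `config.json`, `intervention.json`, and `REQUIRED_KEYS` are hypothetical names, JSON stands in for the real tensors, and this is not pyreft's loading code.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical layout: config.json and intervention.json in one directory.
CONFIG_FILE = "config.json"
WEIGHTS_FILE = "intervention.json"
REQUIRED_KEYS = ("layer", "component", "low_rank_dimension")


def load_checkpoint(ckpt_dir):
    """Return (config, weights) read from ckpt_dir.

    The config comes from the checkpoint itself, so the caller does not
    re-declare layer/component/rank by hand.
    """
    ckpt = Path(ckpt_dir)
    config_path = ckpt / CONFIG_FILE
    if not config_path.is_file():
        # Weights-only checkpoint: the situation described in this issue.
        raise FileNotFoundError(f"no {CONFIG_FILE} in {ckpt}; config must be rebuilt manually")
    config = json.loads(config_path.read_text())
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise ValueError(f"config is missing keys: {missing}")
    weights = json.loads((ckpt / WEIGHTS_FILE).read_text())
    return config, weights


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp)
        (ckpt / CONFIG_FILE).write_text(json.dumps(
            {"layer": 15, "component": "block_output", "low_rank_dimension": 4}))
        (ckpt / WEIGHTS_FILE).write_text(json.dumps({"rotate": [[0.0] * 4]}))
        config, weights = load_checkpoint(ckpt)
        print(config["layer"], config["low_rank_dimension"])
```

Failing loudly on a missing config makes a weights-only checkpoint obvious at load time, rather than surfacing later as a shape mismatch from a hand-written config that differs from the training one.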

frankaging self-assigned this on Jun 13, 2024
frankaging changed the title from "[P1] Can't load ReftModel from checkpoint trained by ReftTrainer" to "[P1] Refactor ReftTrainer to save artifacts with the config" on Jun 13, 2024
frankaging added the enhancement (New feature or request) and engineering labels on Jun 13, 2024