The issue is that ReftTrainer.save_model does not save the ReftConfig, only the intervention.
As a workaround, we can load the model from the checkpoint by reinstantiating the config manually:

```python
import pyreft
import pyvene as pv
import torch
from transformers import AutoModelForCausalLM

device = "cuda"
# model_name: the same base model that was used for training (defined elsewhere)
reft_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map=device)

# Reinstantiate the ReftConfig used for training, since the checkpoint lacks it
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=reft_model.config.hidden_size, low_rank_dimension=4)})
reft_model = pv.IntervenableModel(reft_config, reft_model)
# Load the intervention weights saved by ReftTrainer.save_model
reft_model.load_intervention("./tmp/checkpoint-78/intervenable_model")

# Move each intervention module onto the base model's device
for v in reft_model.interventions.values():
    v[0].to(device)
```
Please let me know if I am missing something!
Thanks,
Bryan
frankaging changed the title from "Can't load ReftModel from checkpoint trained by ReftTrainer" to "[P1] Can't load ReftModel from checkpoint trained by ReftTrainer" on Jun 13, 2024.
@BryanWBear Yes! I am turning this ticket into a feature request, which I can work on later. Thanks for bringing this up.
For now, you can also save your ReFT model with reft_model.save(<your_dir>), which uses our own API instead of the trainer's. I think this API saves the config as well as the other artifacts.
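The practical difference between the two save paths is whether the config travels with the weights. What follows is a minimal, stdlib-only sketch of that pattern, not pyreft's implementation: InterventionConfigSketch, save_checkpoint, and load_checkpoint are hypothetical names, and JSON lists stand in for tensors so the example runs without PyTorch.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


# Hypothetical stand-in for the fields passed to pyreft.ReftConfig in the
# workaround; this is NOT pyreft's config class, just the same info as data.
@dataclass
class InterventionConfigSketch:
    layer: int
    component: str
    low_rank_dimension: int
    intervention_type: str
    embed_dim: int


def save_checkpoint(directory: Path, config: InterventionConfigSketch,
                    weights: dict) -> None:
    """Write the config next to the weights in the same directory."""
    directory.mkdir(parents=True, exist_ok=True)
    (directory / "config.json").write_text(json.dumps(asdict(config)))
    # Plain nested lists stand in for tensors to keep this stdlib-only.
    (directory / "weights.json").write_text(json.dumps(weights))


def load_checkpoint(directory: Path) -> tuple[InterventionConfigSketch, dict]:
    """Rebuild the config from disk first, then read the weights it describes."""
    config_dict = json.loads((directory / "config.json").read_text())
    config = InterventionConfigSketch(**config_dict)
    weights = json.loads((directory / "weights.json").read_text())
    return config, weights


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "checkpoint-78"
        cfg = InterventionConfigSketch(
            layer=15, component="block_output", low_rank_dimension=4,
            intervention_type="LoreftIntervention", embed_dim=8)
        # Hypothetical 8 x 4 placeholder matrix (embed_dim x low_rank_dimension).
        save_checkpoint(ckpt, cfg, {"placeholder_weight": [[0.0] * 4 for _ in range(8)]})
        restored_cfg, restored_weights = load_checkpoint(ckpt)
        assert restored_cfg == cfg
        print(restored_cfg)
```

Because the loader reads config.json before the weights, the caller never has to retype values such as layer 15 or rank 4. A real implementation would also need a way to map intervention_type back to an intervention class; this sketch only stores the name.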
frankaging changed the title from "[P1] Can't load ReftModel from checkpoint trained by ReftTrainer" to "[P1] Refactor ReftTrainer to save artifacts with the config" on Jun 13, 2024.