
Commit

fix(accelerate_ppo_trainer): no resizing when using peft reference
congchan authored Jul 4, 2023
1 parent d2c4b3b commit fd58c49
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -69,13 +69,15 @@ def __init__(self, config: TRLConfig, **kwargs):
 
         # Set up a reference model when hydra heads are not used
         if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
+            # Full Reference Copy
             self.ref_model = self.get_arch(self.config)
             self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer))
             self.ref_model.to(self.accelerator.device)
             self.ref_model.eval()
-        else:
-            # resize hydra heads
+        elif hasattr(self.model, "frozen_head"):
+            # Hydra Reference: Use the frozen base layers and head as the reference model, resize hydra heads
             self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
+        # TODO: else PEFT Reference, do something?
 
         # Set up the KL controller
         # This helps prevent large divergences in the controller (policy)

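The TODO on the new PEFT branch reflects why the PEFT case needs neither a full reference copy nor an embedding resize: with a PEFT/LoRA policy, the frozen base model itself serves as the reference, and it already shares the policy's (resized) token embeddings. A minimal sketch of that idea follows, assuming a peft.PeftModel-style object exposing peft's disable_adapter() context manager; the helper name and signature are illustrative, not trlx's actual implementation.

import torch

def peft_reference_logits(peft_model, input_ids, attention_mask=None):
    # Hypothetical helper (not trlx's code): temporarily disable the adapter
    # layers so the forward pass runs through the frozen base weights only.
    # The base model shares the policy's already-resized token embeddings,
    # so no separate reference copy and no second resize call are needed.
    with torch.no_grad(), peft_model.disable_adapter():
        outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.logits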