Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
Seq2Seq PPO + dtype default value
  • Loading branch information
glerzing committed Aug 7, 2023
1 parent 4a6896b commit 9dae10c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions trlx/models/modeling_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class CausalLMOutputWithValue(ModelOutput):
value: Optional[torch.FloatTensor] = None


def make_value_branch(base_model, num_value_layers_unfrozen, dtype):
def make_value_branch(base_model, num_value_layers_unfrozen, dtype=torch.float32):
value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
if num_value_layers_unfrozen == 0:
return value_head
Expand Down Expand Up @@ -1211,9 +1211,12 @@ def __init__(
):
super().__init__(base_model, peft_config=peft_config)
# TODO: Support Seq2Seq value branching
parameter = next(hf_get_lm_head(self.base_model).parameters())
dtype = parameter.dtype
device = parameter.device
if num_value_layers_unfrozen > 0:
raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1, dtype).to(device)

def forward(
self,
Expand Down
2 changes: 1 addition & 1 deletion trlx/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import transformers


def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
"""Returns a generic sequential MLP head."""
return nn.Sequential(
nn.Linear(n_embd, n_embd * 2, dtype=dtype),
Expand Down

0 comments on commit 9dae10c

Please sign in to comment.