diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 8a0b5faf5..f47e761fd 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -252,7 +252,7 @@ class CausalLMOutputWithValue(ModelOutput):
     value: Optional[torch.FloatTensor] = None
 
 
-def make_value_branch(base_model, num_value_layers_unfrozen, dtype):
+def make_value_branch(base_model, num_value_layers_unfrozen, dtype=torch.float32):
     value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
     if num_value_layers_unfrozen == 0:
         return value_head
@@ -1211,9 +1211,12 @@ def __init__(
     ):
         super().__init__(base_model, peft_config=peft_config)
         # TODO: Support Seq2Seq value branching
+        parameter = next(hf_get_lm_head(self.base_model).parameters())
+        dtype = parameter.dtype
+        device = parameter.device
         if num_value_layers_unfrozen > 0:
             raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
-        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
+        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1, dtype).to(device)
 
     def forward(
         self,
diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index 6e737c080..84134642b 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -10,7 +10,7 @@
 import transformers
 
 
-def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
+def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
     """Returns a generic sequential MLP head."""
     return nn.Sequential(
         nn.Linear(n_embd, n_embd * 2, dtype=dtype),
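
For context (not part of the patch): the failure mode being fixed is a dtype/device mismatch between the value head, which `make_head` builds in float32 by default, and a base model loaded in reduced precision. Below is a minimal sketch reproducing the mismatch with toy sizes; `lm_head` and the bfloat16 load are illustrative stand-ins (bfloat16 substitutes for float16 so the snippet runs on CPU), not trlx internals.

```python
# Sketch only: toy sizes, bfloat16 standing in for a half-precision base model.
import torch
import torch.nn as nn


# Mirrors the patched make_head signature (dtype annotated as torch.dtype).
def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(n_embd, n_embd * 2, dtype=dtype),
        nn.ReLU(),
        nn.Linear(n_embd * 2, out, dtype=dtype),
    )


# Stand-in for a base model's lm_head loaded in reduced precision;
# hidden states coming out of such a model share its dtype.
lm_head = nn.Linear(8, 8, dtype=torch.bfloat16)
hidden = torch.randn(1, 4, 8, dtype=torch.bfloat16)

# Old behaviour: a default float32 head meets bfloat16 hidden states and fails.
try:
    make_head(8, 1)(hidden)
except RuntimeError as err:
    print(f"mismatch: {err}")

# Patched behaviour: derive dtype/device from an existing parameter,
# as the diff does via next(hf_get_lm_head(...).parameters()).
parameter = next(lm_head.parameters())
v_head = make_head(8, 1, parameter.dtype).to(parameter.device)
print(v_head(hidden).shape)  # torch.Size([1, 4, 1])
```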