diff --git a/realhf/impl/model/interface/ppo_interface.py b/realhf/impl/model/interface/ppo_interface.py index 7e5a71d6..b557f3d7 100755 --- a/realhf/impl/model/interface/ppo_interface.py +++ b/realhf/impl/model/interface/ppo_interface.py @@ -267,7 +267,7 @@ def inference( # This post_hook will gather log probabilities in mini-batches, # reducing peak memory usage. def calc_logprobs(logits, input_): - logits /= self.generation_config.temperature + logits /= self.gconfig.temperature if ( "packed_logits_mask" in input_.data and input_.data["packed_logits_mask"] is not None