diff --git a/xtuner/model/dpo.py b/xtuner/model/dpo.py
index 9384ddb34..faaa43402 100644
--- a/xtuner/model/dpo.py
+++ b/xtuner/model/dpo.py
@@ -52,10 +52,13 @@ def __init__(self,
         self.beta = beta

         if ref_llm is not None:
-            ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
+            ref_llm = self.build_llm_from_cfg(
+                ref_llm, kwargs.get('use_varlen_attn', False),
+                kwargs.get('max_position_embeddings', None))
             self.ref_llm = disable_grad(ref_llm)
         else:
-            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
+            self.ref_llm = None if self.use_lora else create_reference_model(
+                self.llm)

     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 9c3fa38c9..522950489 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -80,7 +80,8 @@ def __init__(self,
                  max_position_embeddings=None):
         super().__init__()

-        self.llm = self._build_llm_from_cfg(llm, use_varlen_attn, max_position_embeddings)
+        self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
+                                           max_position_embeddings)

         if tokenizer is not None:
             if isinstance(tokenizer, dict):
@@ -115,19 +116,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn

-
-    def _build_llm_from_cfg(self, llm_cfg, use_varlen_attn, max_position_embeddings):
+    def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
+                           max_position_embeddings):
         # For forward
         with LoadWoInit():
             if isinstance(llm_cfg, dict):
-                llm = self._dispatch_lm_model_cfg(llm_cfg, max_position_embeddings)
+                llm = self._dispatch_lm_model_cfg(llm_cfg,
+                                                  max_position_embeddings)
             llm = self._build_from_cfg_or_module(llm)

         llm.config.use_cache = False
         dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
         return llm

-
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()
