diff --git a/Phi-3-V/train.py b/Phi-3-V/train.py index 40f31c0..9affb77 100644 --- a/Phi-3-V/train.py +++ b/Phi-3-V/train.py @@ -923,7 +923,7 @@ def train(attn_implementation=None): **bnb_model_from_pretrained_args ) else: - model = LlavaLlamaForCausalLM.from_pretrained( + model = LlavaPhiForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=attn_implementation,