diff --git a/scripts/fine_tune_llm.py b/scripts/fine_tune_llm.py index 3651b6a..822342e 100644 --- a/scripts/fine_tune_llm.py +++ b/scripts/fine_tune_llm.py @@ -47,7 +47,7 @@ def get_model_and_tokenizer() -> tuple[PreTrainedModel, PreTrainedTokenizer]: model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", - # attn_implementation="flash_attention_2", + attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, quantization_config=bnb_config, )