From eb7ef70c3d9a2f6a6bb199a279cc77e0f9d85791 Mon Sep 17 00:00:00 2001
From: Albert Sawczyn
Date: Wed, 13 Mar 2024 11:32:45 +0000
Subject: [PATCH] feat: enable flash attention

---
 scripts/fine_tune_llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/fine_tune_llm.py b/scripts/fine_tune_llm.py
index 3651b6a..822342e 100644
--- a/scripts/fine_tune_llm.py
+++ b/scripts/fine_tune_llm.py
@@ -47,7 +47,7 @@ def get_model_and_tokenizer() -> tuple[PreTrainedModel, PreTrainedTokenizer]:
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         device_map="auto",
-        # attn_implementation="flash_attention_2",
+        attn_implementation="flash_attention_2",
         torch_dtype=torch.bfloat16,
         quantization_config=bnb_config,
     )
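
Note: for context, below is a minimal sketch of what the full get_model_and_tokenizer loader might look like with flash attention enabled. The model_id value and the 4-bit BitsAndBytesConfig shown here are assumptions for illustration only; the real values live elsewhere in scripts/fine_tune_llm.py and are not part of this patch. Using attn_implementation="flash_attention_2" also assumes the flash-attn package is installed and the GPU supports it.

    # Sketch of a loader with flash attention enabled (not the patched file verbatim).
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        PreTrainedModel,
        PreTrainedTokenizer,
    )

    def get_model_and_tokenizer() -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        model_id = "mistralai/Mistral-7B-v0.1"  # assumed; the patch does not specify the model
        bnb_config = BitsAndBytesConfig(        # assumed 4-bit quantization config for illustration
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            attn_implementation="flash_attention_2",  # the line enabled by this patch
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        return model, tokenizer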