feat: finish script
asawczyn committed Apr 9, 2024
1 parent 9945e90 commit c633615
Showing 2 changed files with 10 additions and 19 deletions.
pyproject.toml: 1 change (0 additions, 1 deletion)
@@ -19,4 +19,3 @@ module = [
"datasets",
]
ignore_missing_imports = true

scripts/fine_tune_llm.py: 28 changes (10 additions, 18 deletions)
@@ -5,19 +5,14 @@
import typer
from peft.tuners.lora.config import LoraConfig
from trl import SFTTrainer
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    PreTrainedTokenizer,
    PreTrainedModel,
    Trainer,
)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from datasets import load_dataset, DatasetDict, Dataset, IterableDatasetDict, IterableDataset
assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'

from transformers import TrainingArguments
from datasets import load_dataset

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

def main(
    model_name: str = typer.Option(
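The assert in the hunk above gates FlashAttention-2 on the GPU's compute capability (8.0 or newer, i.e. Ampere-class hardware). As a hedged aside, the same check can be wrapped so a script degrades to PyTorch's SDPA attention instead of aborting; the helper name and the "sdpa" fallback are assumptions, not part of this commit.

import torch


def pick_attn_implementation() -> str:
    # FlashAttention-2 needs compute capability >= 8.0 (Ampere or newer),
    # which is what the assert in the diff checks; "sdpa" is an assumed fallback.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        return "flash_attention_2"
    return "sdpa"
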
@@ -59,10 +54,8 @@ def get_model_and_tokenizer(
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    # BitsAndBytesConfig int-4 config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
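As context for the quantization settings above, a brief annotated sketch of the same BitsAndBytesConfig; the comments are explanatory additions, not part of the commit.

import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load weights quantized to 4 bits
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 rather than plain int4
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
)
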

    # Load model and tokenizer
@@ -71,15 +64,14 @@ def get_trainer(
        device_map="auto",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
        quantization_config=bnb_config
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.padding_side = "right"  # to prevent warnings

    return model, tokenizer
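For reference, a standalone sketch of the loading step shown in this hunk; the function name is hypothetical, model_name and tokenizer_name stand in for the script's CLI options, and the comments are additions.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def load_model_and_tokenizer(
    model_name: str,
    tokenizer_name: str,
    bnb_config: BitsAndBytesConfig,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    # Mirrors the call above: shard across available devices, use FlashAttention-2
    # kernels, keep compute in bfloat16, and load the weights int-4 quantized.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.padding_side = "right"  # avoids padding-side warnings during SFT
    return model, tokenizer
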


def get_peft_config() -> LoraConfig:
def get_peft_config():
    peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
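The LoraConfig call is cut off by the hunk boundary above; a typical completion for a causal-LM adapter is sketched below. The values for r, bias, target_modules, and task_type are assumptions, not taken from this commit.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=8,                 # from the diff above
    lora_dropout=0.05,            # from the diff above
    r=16,                         # assumed adapter rank
    bias="none",                  # assumed: keep bias terms frozen
    target_modules="all-linear",  # assumed: adapt every linear projection
    task_type="CAUSAL_LM",
)
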
@@ -134,7 +126,7 @@ def get_trainer(
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
        }
    )

    return trainer
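The dataset_kwargs above disable the tokenizer's own special-token handling because the chat template already inserts those tokens. A hedged sketch of how such a trainer is typically assembled follows; dataset_text_field, max_seq_length, packing, and the TrainingArguments values are assumptions, and the exact SFTTrainer keyword set depends on the trl version.

from datasets import Dataset
from peft import LoraConfig
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainingArguments
from trl import SFTTrainer


def build_trainer(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    train_dataset: Dataset,
    peft_config: LoraConfig,
    output_dir: str = "output",
) -> SFTTrainer:
    # Keyword names follow the trl releases current at the time of this commit;
    # newer releases move several of these into SFTConfig.
    return SFTTrainer(
        model=model,
        args=TrainingArguments(output_dir=output_dir, bf16=True),
        train_dataset=train_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
        dataset_text_field="text",  # assumed column with the templated chat text
        max_seq_length=2048,        # assumed, not taken from the diff
        packing=True,               # assumed: pack short samples into full sequences
        dataset_kwargs={
            "add_special_tokens": False,   # the chat template already adds them
            "append_concat_token": False,  # no extra separator token needed
        },
    )
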
