add support for mixtral
tohrnii committed Feb 21, 2024
1 parent a030e80 commit a55b740
Showing 4 changed files with 1,072 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
@@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

**/runpod
219 changes: 219 additions & 0 deletions examples/Alpaca_Mixtral.ipynb
@@ -0,0 +1,219 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c4b2a910-40ce-48f9-91b6-11d5eec547f4",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/workspace/unsloth')\n",
"from unsloth.models.mixtral import FastMixtralModel\n",
"from unsloth import FastLanguageModel\n",
"import torch\n",
"max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!\n",
"dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
"load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.\n",
"\n",
"model, tokenizer = FastMixtralModel.from_pretrained(\n",
" model_name = \"mistralai/Mixtral-8x7B-v0.1\", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B\n",
" max_seq_length = max_seq_length,\n",
" dtype = dtype,\n",
" load_in_4bit = load_in_4bit,\n",
" # token = \"hf_...\", # use one if using gated models like meta-llama/Llama-2-7b-hf\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77182aa0-762f-4e80-bdf8-8af785fc6f97",
"metadata": {},
"outputs": [],
"source": [
"model = FastMixtralModel.get_peft_model(\n",
" model,\n",
" r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
" # \"gate\", \"w1\", \"w2\", \"w3\"],\n",
" lora_alpha = 16,\n",
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
" use_gradient_checkpointing = True,\n",
" random_state = 3407,\n",
" use_rslora = False, # We support rank stabilized LoRA\n",
" loftq_config = None, # And LoftQ\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4847ef2c-cdf5-4905-899d-9fc331fde245",
"metadata": {},
"outputs": [],
"source": [
"alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
"\n",
"### Instruction:\n",
"{}\n",
"\n",
"### Input:\n",
"{}\n",
"\n",
"### Response:\n",
"{}\"\"\"\n",
"\n",
"EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN\n",
"def formatting_prompts_func(examples):\n",
" instructions = examples[\"instruction\"]\n",
" inputs = examples[\"input\"]\n",
" outputs = examples[\"output\"]\n",
" texts = []\n",
" for instruction, input, output in zip(instructions, inputs, outputs):\n",
" # Must add EOS_TOKEN, otherwise your generation will go on forever!\n",
" text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN\n",
" texts.append(text)\n",
" return { \"text\" : texts, }\n",
"pass\n",
"\n",
"from datasets import load_dataset\n",
"dataset = load_dataset(\"yahma/alpaca-cleaned\", split = \"train\")\n",
"dataset = dataset.map(formatting_prompts_func, batched = True,)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "508de7d8-64be-407d-8899-f8c737ed3650",
"metadata": {},
"outputs": [],
"source": [
"from trl import SFTTrainer\n",
"from transformers import TrainingArguments\n",
"\n",
"trainer = SFTTrainer(\n",
" model = model,\n",
" tokenizer = tokenizer,\n",
" train_dataset = dataset,\n",
" dataset_text_field = \"text\",\n",
" max_seq_length = max_seq_length,\n",
" dataset_num_proc = 2,\n",
" packing = False, # Can make training 5x faster for short sequences.\n",
" args = TrainingArguments(\n",
" per_device_train_batch_size = 2,\n",
" gradient_accumulation_steps = 4,\n",
" warmup_steps = 5,\n",
" max_steps = 60,\n",
" learning_rate = 2e-4,\n",
" fp16 = not torch.cuda.is_bf16_supported(),\n",
" bf16 = torch.cuda.is_bf16_supported(),\n",
" logging_steps = 1,\n",
" optim = \"adamw_8bit\",\n",
" weight_decay = 0.01,\n",
" lr_scheduler_type = \"linear\",\n",
" seed = 3407,\n",
" output_dir = \"outputs\",\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e33a6e5-a8b9-402c-8419-10c3e969a561",
"metadata": {},
"outputs": [],
"source": [
"#@title Show current memory stats\n",
"gpu_stats = torch.cuda.get_device_properties(0)\n",
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9406a6c0-ce05-4c1b-bbb1-60ed2b3a7418",
"metadata": {},
"outputs": [],
"source": [
"trainer_stats = trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47f6cd90-4b6a-4165-a934-ee8d630d1f9d",
"metadata": {},
"outputs": [],
"source": [
"#@title Show final memory and time stats\n",
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
"used_percentage = round(used_memory /max_memory*100, 3)\n",
"lora_percentage = round(used_memory_for_lora/max_memory*100, 3)\n",
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
"print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e13e692-e4f7-46d1-9df0-870695fd7e9e",
"metadata": {},
"outputs": [],
"source": [
"# alpaca_prompt = Copied from above\n",
"FastMixtralModel.for_inference(model) # Enable native 2x faster inference\n",
"inputs = tokenizer(\n",
"[\n",
" alpaca_prompt.format(\n",
" \"Continue the fibonnaci sequence.\", # instruction\n",
" \"1, 1, 2, 3, 5, 8\", # input\n",
" \"\", # output - leave this blank for generation!\n",
" )\n",
"], return_tensors = \"pt\").to(\"cuda\")\n",
"\n",
"outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)\n",
"tokenizer.batch_decode(outputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09978b57-d549-4888-aaba-459abb683545",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 3 additions & 1 deletion unsloth/models/loader.py
@@ -14,6 +14,7 @@

from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .mixtral import FastMixtralModel
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
@@ -97,8 +98,9 @@ def from_pretrained(

model_type = model_config.model_type

if model_type == "llama": dispatch_model = FastLlamaModel
if model_type == "llama": dispatch_model = FastLlamaModel
elif model_type == "mistral": dispatch_model = FastMistralModel
elif model_type == "mixtral": dispatch_model = FastMixtralModel
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\