Skip to content

Commit

Permalink
Ollama (#665)
Browse files Browse the repository at this point in the history
* Update llama.py

* offload

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* continued pretraining trainer

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* is_bfloat16_supported

* Update __init__.py

* Update README.md

* Update llama.py

* is_bfloat16_supported

* Update __init__.py

* Mistral v3

* Phi 3 medium

* Update chat_templates.py

* Update chat_templates.py

* Phi-3

* Update save.py

* Update README.md

Mistral v3 to Mistral v0.3

* Untrained tokens

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update save.py

* Update save.py

* Update save.py

* checkpoint

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* accelerate

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* train_dataloader

* Update llama.py

* Update llama.py

* Update llama.py

* use_fast_convert

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* remove_special_tokens

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Update chat_templates.py

* Support bfloat16 GGUF

* Update save.py

* Update llama.py

* fast_forward_inference

* Update mapper.py

* Update loader.py

* Update llama.py

* Update tokenizer_utils.py

* info

* edits

* Create chat template

* Fix tokenizer

* Update tokenizer_utils.py

* fix case where gguf saving fails due to first_conversion dtype (#630)

* Support revision parameter in FastLanguageModel.from_pretrained (#629)

* support `revision` parameter

* match unsloth formatting of named parameters

* clears any selected_adapters before calling internal_model.save_pretrained (#609)

* Update __init__.py (#602)

Check for incompatible modules before importing unsloth

* Fixed unsloth/tokenizer_utils.py for chat training (#604)

* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)

* Add save to llama.cpp GGML to save.py.

* Fix conversion command and path of convert to GGML function.

* Add autosaving lora to the GGML function

* Create lora save function for conversion to GGML

* Test fix #2 for saving lora

* Test fix #3 to save  the lora adapters to convert to GGML

* Remove unwated tokenizer saving for conversion to ggml and added a few print statements.

* Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages.

* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments.

* Test fix 1 for arch

* Test fix 2 new Mistral error.

* Test fix 3

* Revert to old version for testing.

* Upload issue test fix 1

* Fix 2 uploading ggml

* Positional ags added.

* Temporray remove positional args

* Fix upload again!!!

* Add print statements and fix link

* Make the calling name better

* Create local saving for GGML

* Add choosing directory to save local GGML.

* Fix lil variable error in the save_to_custom_dir func

* docs: Add LoraConfig parameters documentation (#619)

* llama.cpp failing (#371)

llama.cpp is failing to generate quantize versions for the trained models.

Error:

```bash
You might have to compile llama.cpp yourself, then run this again.
You do not need to close this Python program. Run the following commands in a new terminal:
You must run this in the same folder as you're saving your model.
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
Once that's done, redo the quantization.
```

But when i do clone this with recursive it works.

Co-authored-by: Daniel Han <[email protected]>

* fix libcuda_dirs import for triton 3.0 (#227)

* fix libcuda_dirs import for triton 3.0

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update save.py

* Update __init__.py

* Update fast_lora.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update save.py

* Update save.py

* quantize now llama-quantize

* Update chat_templates.py

* Update loader.py

* Update mapper.py

* Update __init__.py

* embedding size

* Update qwen2.py

* docs

* Update README.md

* Update qwen2.py

* README: Fix minor typo. (#559)

* README: Fix minor typo.

One-character typo fix while reading.

* Update README.md

---------

Co-authored-by: Daniel Han <[email protected]>

* Update mistral.py

* Update qwen2.py

* Update qwen2.py

* Update qwen2.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update README.md

* FastMistralModel

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Auto check rope scaling

* Update llama.py

* Update llama.py

* Update llama.py

* GPU support

* Typo

* Update gemma.py

* gpu

* Multiple GGUF saving

* Update save.py

* Update save.py

* check PEFT and base

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

* Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651)

* Nightly (#649)

* Update llama.py

* offload

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* continued pretraining trainer

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* is_bfloat16_supported

* Update __init__.py

* Update README.md

* Update llama.py

* is_bfloat16_supported

* Update __init__.py

* Mistral v3

* Phi 3 medium

* Update chat_templates.py

* Update chat_templates.py

* Phi-3

* Update save.py

* Update README.md

Mistral v3 to Mistral v0.3

* Untrained tokens

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update save.py

* Update save.py

* Update save.py

* checkpoint

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update llama.py

* accelerate

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* train_dataloader

* Update llama.py

* Update llama.py

* Update llama.py

* use_fast_convert

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* remove_special_tokens

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update llama.py

* Update chat_templates.py

* Support bfloat16 GGUF

* Update save.py

* Update llama.py

* fast_forward_inference

* Update mapper.py

* Update loader.py

* Update llama.py

* Update tokenizer_utils.py

* info

* edits

* Create chat template

* Fix tokenizer

* Update tokenizer_utils.py

* fix case where gguf saving fails due to first_conversion dtype (#630)

* Support revision parameter in FastLanguageModel.from_pretrained (#629)

* support `revision` parameter

* match unsloth formatting of named parameters

* clears any selected_adapters before calling internal_model.save_pretrained (#609)

* Update __init__.py (#602)

Check for incompatible modules before importing unsloth

* Fixed unsloth/tokenizer_utils.py for chat training (#604)

* Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345)

* Add save to llama.cpp GGML to save.py.

* Fix conversion command and path of convert to GGML function.

* Add autosaving lora to the GGML function

* Create lora save function for conversion to GGML

* Test fix #2 for saving lora

* Test fix #3 to save  the lora adapters to convert to GGML

* Remove unwated tokenizer saving for conversion to ggml and added a few print statements.

* Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages.

* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments.

* Test fix 1 for arch

* Test fix 2 new Mistral error.

* Test fix 3

* Revert to old version for testing.

* Upload issue test fix 1

* Fix 2 uploading ggml

* Positional ags added.

* Temporray remove positional args

* Fix upload again!!!

* Add print statements and fix link

* Make the calling name better

* Create local saving for GGML

* Add choosing directory to save local GGML.

* Fix lil variable error in the save_to_custom_dir func

* docs: Add LoraConfig parameters documentation (#619)

* llama.cpp failing (#371)

llama.cpp is failing to generate quantize versions for the trained models.

Error:

```bash
You might have to compile llama.cpp yourself, then run this again.
You do not need to close this Python program. Run the following commands in a new terminal:
You must run this in the same folder as you're saving your model.
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j
Once that's done, redo the quantization.
```

But when i do clone this with recursive it works.

Co-authored-by: Daniel Han <[email protected]>

* fix libcuda_dirs import for triton 3.0 (#227)

* fix libcuda_dirs import for triton 3.0

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: Daniel Han <[email protected]>

* Update save.py

* Update __init__.py

* Update fast_lora.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update save.py

* Update save.py

* quantize now llama-quantize

* Update chat_templates.py

* Update loader.py

* Update mapper.py

* Update __init__.py

* embedding size

* Update qwen2.py

* docs

* Update README.md

* Update qwen2.py

* README: Fix minor typo. (#559)

* README: Fix minor typo.

One-character typo fix while reading.

* Update README.md

---------

Co-authored-by: Daniel Han <[email protected]>

* Update mistral.py

* Update qwen2.py

* Update qwen2.py

* Update qwen2.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update README.md

* FastMistralModel

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Update mistral.py

* Auto check rope scaling

* Update llama.py

* Update llama.py

* Update llama.py

* GPU support

* Typo

* Update gemma.py

* gpu

* Multiple GGUF saving

* Update save.py

* Update save.py

* check PEFT and base

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

---------

Co-authored-by: Michael Han <[email protected]>
Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>

* Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving

* Implemented better list management and then forgot to actually call the new list variable, fixed

* Check type of given quantization method and return type error if not list or string

* Update save.py

---------

Co-authored-by: Daniel Han <[email protected]>
Co-authored-by: Michael Han <[email protected]>
Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>

* Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652)

This reverts commit 30605de.

* Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653)

This reverts commit e2b2083.

* Update llama.py

* peft

* patch

* Update loader.py

* retrain

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* offload

* Update llama.py

* Create a starter script for command-line training to integrate in ML ops pipelines. (#623)

* Update chat_templates.py

* Ollama

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

---------

Co-authored-by: Michael Han <[email protected]>
Co-authored-by: Eliot Hall <[email protected]>
Co-authored-by: Rickard Edén <[email protected]>
Co-authored-by: XiaoYang <[email protected]>
Co-authored-by: Oseltamivir <[email protected]>
Co-authored-by: mahiatlinux <[email protected]>
Co-authored-by: Sébastien De Greef <[email protected]>
Co-authored-by: Alberto Ferrer <[email protected]>
Co-authored-by: Thomas Viehmann <[email protected]>
Co-authored-by: Walter Korman <[email protected]>
Co-authored-by: ArcadaLabs-Jason <[email protected]>
  • Loading branch information
12 people authored Jun 18, 2024
1 parent 64bb8cf commit 8770308
Show file tree
Hide file tree
Showing 4 changed files with 569 additions and 34 deletions.
221 changes: 221 additions & 0 deletions unsloth-cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
#!/usr/bin/env python3

"""
🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
This script is designed as a starting point for fine-tuning your models using unsloth.
It includes configurable options for model loading, PEFT parameters, training arguments,
and model saving/pushing functionalities.
You will likely want to customize this script to suit your specific use case
and requirements.
Here are a few suggestions for customization:
- Modify the dataset loading and preprocessing steps to match your data.
- Customize the model saving and pushing configurations.
Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
--r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
--random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
--warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
--weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
--report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
--push_model --hub_path "hf/model" --hub_token "your_hf_token"
To see a full list of configurable options, use:
python unsloth-cli.py --help
Happy fine-tuning!
"""

import argparse

def run(args):
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import logging
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)

# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_seq_length,
dtype=args.dtype,
load_in_4bit=args.load_in_4bit,
)

# Configure PEFT model
model = FastLanguageModel.get_peft_model(
model,
r=args.r,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.bias,
use_gradient_checkpointing=args.use_gradient_checkpointing,
random_state=args.random_state,
use_rslora=args.use_rslora,
loftq_config=args.loftq_config,
)

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
instructions = examples["instruction"]
inputs = examples["input"]
outputs = examples["output"]
texts = []
for instruction, input, output in zip(instructions, inputs, outputs):
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
texts.append(text)
return {"text": texts}

# Load and format dataset
dataset = load_dataset(args.dataset, split="train")
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Data is formatted and ready!")

# Configure training arguments
training_args = TrainingArguments(
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
warmup_steps=args.warmup_steps,
max_steps=args.max_steps,
learning_rate=args.learning_rate,
fp16=not is_bfloat16_supported(),
bf16=is_bfloat16_supported(),
logging_steps=args.logging_steps,
optim=args.optim,
weight_decay=args.weight_decay,
lr_scheduler_type=args.lr_scheduler_type,
seed=args.seed,
output_dir=args.output_dir,
report_to=args.report_to,
)

# Initialize trainer
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=args.max_seq_length,
dataset_num_proc=2,
packing=False,
args=training_args,
)

# Train model
trainer_stats = trainer.train()

# Save model
if args.save_model:
# if args.quantization_method is a list, we will save the model for each quantization method
if args.save_gguf:
if isinstance(args.quantization, list):
for quantization_method in args.quantization:
print(f"Saving model with quantization method: {quantization_method}")
model.save_pretrained_gguf(
args.save_path,
tokenizer,
quantization_method=quantization_method,
)
if args.push_model:
model.push_to_hub_gguf(
hub_path=args.hub_path,
hub_token=args.hub_token,
quantization_method=quantization_method,
)
else:
print(f"Saving model with quantization method: {args.quantization}")
model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
if args.push_model:
model.push_to_hub_gguf(
hub_path=args.hub_path,
hub_token=args.hub_token,
quantization_method=quantization_method,
)
else:
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
if args.push_model:
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
else:
print("Warning: The model is not saved!")


if __name__ == "__main__":

# Define argument parser
parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")

model_group = parser.add_argument_group("🤖 Model Options")
model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")

lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")


training_group = parser.add_argument_group("🎓 Training Options")
training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")


# Report/Logging arguments
report_group = parser.add_argument_group("📊 Report Options")
report_group.add_argument('--report_to', type=str, default="tensorboard",
choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")

# Saving and pushing arguments
save_group = parser.add_argument_group('💾 Save Model Options')
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")

push_group = parser.add_argument_group('🚀 Push Model Options')
push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")

args = parser.parse_args()
run(args)
Loading

0 comments on commit 8770308

Please sign in to comment.