forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add KD distributed recipe (pytorch#1631)
- Loading branch information
1 parent
48a8449
commit 09c2619
Showing
5 changed files
with
1,502 additions
and
0 deletions.
There are no files selected for viewing
130 changes: 130 additions & 0 deletions
130
recipes/configs/llama3_2/knowledge_distillation_distributed.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py | ||
# using a teacher and student model | ||
# | ||
# This config assumes that you've run the following commands before launching KD: | ||
# First download the student and teacher models | ||
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" | ||
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" | ||
# | ||
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: | ||
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora | ||
# | ||
# To launch on 2 devices, run the following command from root: | ||
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed | ||
# | ||
# This config works best for distilling on 2+ devices. | ||
|
||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.llama3_2.lora_llama3_2_1b | ||
lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] | ||
apply_lora_to_mlp: True | ||
apply_lora_to_output: False | ||
lora_rank: 64 | ||
lora_alpha: 128 | ||
lora_dropout: 0.0 | ||
|
||
teacher_model: | ||
_component_: torchtune.models.llama3_1.llama3_1_8b | ||
|
||
# Tokenizer | ||
tokenizer: | ||
_component_: torchtune.models.llama3.llama3_tokenizer | ||
path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model | ||
max_seq_len: null | ||
|
||
checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/ | ||
checkpoint_files: [ | ||
model.safetensors | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Llama-3.2-1B-Instruct/ | ||
model_type: LLAMA3 | ||
resume_from_checkpoint: False | ||
save_adapter_weights_only: False | ||
|
||
# Teacher checkpoint | ||
teacher_checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
checkpoint_files: [ | ||
model-00001-of-00004.safetensors, | ||
model-00002-of-00004.safetensors, | ||
model-00003-of-00004.safetensors, | ||
model-00004-of-00004.safetensors | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
model_type: LLAMA3 | ||
|
||
# Dataset and Sampler | ||
dataset: | ||
_component_: torchtune.datasets.alpaca_cleaned_dataset | ||
seed: null | ||
shuffle: True | ||
batch_size: 4 | ||
|
||
# Optimizer and Scheduler | ||
optimizer: | ||
_component_: torch.optim.AdamW | ||
fused: True | ||
weight_decay: 0.01 | ||
lr: 3e-4 | ||
lr_scheduler: | ||
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss | ||
|
||
kd_loss: | ||
_component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss | ||
kd_ratio: 0.5 | ||
|
||
# Training | ||
epochs: 1 | ||
max_steps_per_epoch: null | ||
gradient_accumulation_steps: 32 | ||
|
||
# Logging | ||
output_dir: /tmp/kd_output | ||
metric_logger: | ||
_component_: torchtune.training.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: False | ||
|
||
# Environment | ||
device: cuda | ||
dtype: bf16 | ||
enable_activation_checkpointing: False | ||
|
||
# Showcase the usage of the PyTorch profiler | ||
# Set enabled to False as it's only needed for debugging training | ||
profiler: | ||
_component_: torchtune.training.setup_torch_profiler | ||
|
||
enabled: False | ||
|
||
#Output directory of trace artifacts | ||
output_dir: ${output_dir}/profiling_outputs | ||
|
||
#`torch.profiler.ProfilerActivity` types to trace | ||
cpu: True | ||
cuda: True | ||
|
||
#trace options passed to `torch.profiler.profile` | ||
profile_memory: False | ||
with_stack: False | ||
record_shapes: True | ||
with_flops: False | ||
|
||
# `torch.profiler.schedule` options: | ||
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat | ||
wait_steps: 5 | ||
warmup_steps: 5 | ||
active_steps: 2 | ||
num_cycles: 1 |
123 changes: 123 additions & 0 deletions
123
recipes/configs/qwen2/knowledge_distillation_distributed.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py | ||
# using a teacher and student model | ||
# | ||
# This config assumes that you've run the following commands before launching KD: | ||
# First download the student and teacher models | ||
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None | ||
# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None | ||
# | ||
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: | ||
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora | ||
# | ||
# To launch on 2 devices, run the following command from root: | ||
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed | ||
# | ||
# This config works best for distilling on 2+ devices. | ||
|
||
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.qwen2.lora_qwen2_0_5b | ||
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] | ||
apply_lora_to_mlp: False | ||
lora_rank: 32 | ||
lora_alpha: 64 | ||
|
||
teacher_model: | ||
_component_: torchtune.models.qwen2.qwen2_1_5b | ||
|
||
tokenizer: | ||
_component_: torchtune.models.qwen2.qwen2_tokenizer | ||
path: /tmp/Qwen2-0.5B-Instruct/vocab.json | ||
merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt | ||
max_seq_len: null | ||
|
||
checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Qwen2-0.5B-Instruct | ||
checkpoint_files: [ | ||
model.safetensors | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Qwen2-0.5B-Instruct-kd | ||
model_type: QWEN2 | ||
|
||
teacher_checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune | ||
checkpoint_files: [ | ||
hf_model_0001_0.pt | ||
] | ||
recipe_checkpoint: null | ||
output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune | ||
model_type: QWEN2 | ||
|
||
resume_from_checkpoint: False | ||
|
||
# Dataset and Sampler | ||
dataset: | ||
_component_: torchtune.datasets.alpaca_cleaned_dataset | ||
seed: null | ||
shuffle: True | ||
batch_size: 8 | ||
|
||
# Optimizer and Scheduler | ||
optimizer: | ||
_component_: torch.optim.AdamW | ||
weight_decay: 0.01 | ||
lr: 3e-4 | ||
lr_scheduler: | ||
_component_: torchtune.modules.get_cosine_schedule_with_warmup | ||
num_warmup_steps: 100 | ||
|
||
loss: | ||
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss | ||
|
||
kd_loss: | ||
_component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss | ||
kd_ratio: 0.5 | ||
|
||
# Training | ||
epochs: 1 | ||
max_steps_per_epoch: null | ||
gradient_accumulation_steps: 2 | ||
|
||
# Logging | ||
output_dir: /tmp/qwen_kd | ||
metric_logger: | ||
_component_: torchtune.training.metric_logging.DiskLogger | ||
log_dir: ${output_dir} | ||
log_every_n_steps: 1 | ||
log_peak_memory_stats: False | ||
|
||
# Environment | ||
device: cuda | ||
dtype: bf16 | ||
enable_activation_checkpointing: False | ||
|
||
# Showcase the usage of the PyTorch profiler | ||
# Set enabled to False as it's only needed for debugging training | ||
profiler: | ||
_component_: torchtune.training.setup_torch_profiler | ||
|
||
enabled: False | ||
|
||
#Output directory of trace artifacts | ||
output_dir: ${output_dir}/profiling_outputs | ||
|
||
#`torch.profiler.ProfilerActivity` types to trace | ||
cpu: True | ||
cuda: True | ||
|
||
#trace options passed to `torch.profiler.profile` | ||
profile_memory: False | ||
with_stack: False | ||
record_shapes: True | ||
with_flops: False | ||
|
||
# `torch.profiler.schedule` options: | ||
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat | ||
wait_steps: 5 | ||
warmup_steps: 5 | ||
active_steps: 2 | ||
num_cycles: 1 |
Oops, something went wrong.