Commit 01466a5: Update reporting

chiragjn committed Nov 23, 2024
1 parent c5d3de4
Showing 8 changed files with 88 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -14,7 +14,7 @@ RUN mkdir -p /packages && \
cd /packages && \
git clone https://github.com/truefoundry/axolotl && \
cd axolotl/ && \
-    git checkout 0011a3969eeceffc5140315a41d64048cca0ac30 && \
+    git checkout 1442cf55cd7b1c257a525450996355473fdcbb35 && \
cd /packages/axolotl/ && \
MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --use-pep517 --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore] && \
rm -rf /root/.cache/pip
2 changes: 1 addition & 1 deletion Dockerfile-notebook
@@ -29,7 +29,7 @@ USER jovyan
RUN cd /packages && \
git clone https://github.com/truefoundry/axolotl && \
cd axolotl/ && \
-    git checkout 0011a3969eeceffc5140315a41d64048cca0ac30 && \
+    git checkout 1442cf55cd7b1c257a525450996355473fdcbb35 && \
cd /packages/axolotl/ && \
MAX_JOBS=1 NVCC_APPEND_FLAGS="--threads 1" pip install -U --use-pep517 --no-build-isolation --no-cache-dir -e .[flash-attn,mamba-ssm,optimizers,lion-pytorch,galore]

1 change: 1 addition & 0 deletions config-base.yaml
@@ -116,6 +116,7 @@ cleanup_output_dir_on_start: False
dataset_type: chat # Can be completion | chat
drop_long_sequences: False
logging_dir: ./tensorboard_logs
+merge_adapters_post_train: True
save_model_on_interrupt: False
train_data_uri: null
truefoundry_ml_checkpoint_artifact_name: auto # type: string
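
Note: merge_adapters_post_train defaults to True and is cross-checked against the reporting flags by the plugin validator introduced below. A minimal sketch of a user override, assuming the usual YAML-to-plugin-args flow (file contents illustrative):

    import yaml

    # Hypothetical override: keep only the raw LoRA adapter after training.
    # truefoundry_ml_log_merged_model must then also be off while reporting
    # is enabled, or validation of the plugin args fails.
    overrides = yaml.safe_load("""
    merge_adapters_post_train: False
    truefoundry_ml_log_merged_model: False
    """)
    assert overrides["merge_adapters_post_train"] is False
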
24 changes: 19 additions & 5 deletions plugins/axolotl_truefoundry/axolotl_truefoundry/__init__.py
@@ -11,7 +11,7 @@
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks import GPUStatsCallback
from axolotl.utils.distributed import is_main_process
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
from transformers import Trainer, TrainerCallback
from transformers.integrations import rewrite_logs
from transformers.integrations.integration_utils import TensorBoardCallback
@@ -184,12 +184,16 @@ class LongSequenceStrategy(str, enum.Enum):
class TruefoundryMLPluginArgs(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

+    cleanup_output_dir_on_start: bool = False
+    logging_dir: str = "./tensorboard_logs"
+
    dataset_type: DatasetType = DatasetType.chat
    train_data_uri: Optional[str]
    val_data_uri: Optional[str] = None
    val_set_size: float = 0.1

    long_sequences_strategy: LongSequenceStrategy = LongSequenceStrategy.error
+    merge_adapters_post_train: bool = True

    truefoundry_ml_enable_reporting: bool = False
    truefoundry_ml_repo: Optional[str] = None
@@ -199,10 +203,20 @@ class TruefoundryMLPluginArgs(BaseModel):
    truefoundry_ml_log_merged_model: bool = True
    truefoundry_ml_log_gpu_metrics: bool = False

-    cleanup_output_dir_on_start: bool = False
-    logging_dir: str = "./tensorboard_logs"
-
-    truefoundry_testing_mode: bool = False
+    @model_validator(mode="before")
+    @classmethod
+    def check_merging_settings(cls, data: Any) -> Any:
+        if isinstance(data, dict):
+            if (
+                data.get("truefoundry_ml_enable_reporting")
+                and data.get("truefoundry_ml_log_merged_model")
+                and not data.get("merge_adapters_post_train")
+            ):
+                raise ValueError(
+                    "Cannot log merged model if merge_adapters_post_train is False. "
+                    "Please set merge_adapters_post_train to True or disable logging merged model."
+                )
+        return data


class TrueFoundryMLPlugin(BasePlugin):
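
A quick sanity check of the new validator (a hedged sketch, assuming the plugin package is importable and pydantic v2 semantics, where a ValueError raised inside a model validator surfaces as a ValidationError, itself a ValueError subclass):

    from axolotl_truefoundry import TruefoundryMLPluginArgs

    # Inconsistent combination: reporting wants to log a merged model,
    # but post-train adapter merging is disabled -> validation fails.
    try:
        TruefoundryMLPluginArgs(
            train_data_uri=None,
            truefoundry_ml_enable_reporting=True,
            truefoundry_ml_log_merged_model=True,
            merge_adapters_post_train=False,
        )
    except ValueError as e:
        print(e)  # Cannot log merged model if merge_adapters_post_train is False. ...
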
90 changes: 60 additions & 30 deletions reporting.py
@@ -2,6 +2,7 @@
A very hacky script to test out capturing GPU memory usage against tokens and trainable parameters
Later this will be more automated and parallelized across TrueFoundry Jobs
"""
+import argparse
import itertools
import json
import os
@@ -11,20 +12,27 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List

+import pandas as pd
+import yaml
+from pydantic import BaseModel
from transformers import AutoConfig

-ML_REPO = "llm-ft-reporting"
-
-MODELS = [
-    "Qwen/Qwen2.5-0.5B-Instruct",
-    "Qwen/Qwen2.5-1.5B-Instruct",
-    "Qwen/Qwen2.5-3B-Instruct",
-    "Qwen/Qwen2.5-7B-Instruct",
-    "Qwen/Qwen2.5-14B-Instruct",
-    "Qwen/Qwen2.5-32B-Instruct",
-]
-SEQ_LENS = [512, 1024, 2048, 4096, 8192]
-LORA_RS = [8, 16, 32]
+class ReportingConfig(BaseModel):
+    ml_repo: str = "llm-ft-reporting"
+    base_models: List[str] = [
+        "Qwen/Qwen2.5-0.5B-Instruct",
+        "Qwen/Qwen2.5-1.5B-Instruct",
+        "Qwen/Qwen2.5-3B-Instruct",
+        "Qwen/Qwen2.5-7B-Instruct",
+        "Qwen/Qwen2.5-14B-Instruct",
+        "Qwen/Qwen2.5-32B-Instruct",
+    ]
+    sequence_lens: List[int] = [512, 1024, 2048, 4096, 8192]
+    lora_rs: List[int] = [32]
+    stream_stdout: bool = False
+    stream_stderr: bool = False


COMMAND = """\
accelerate launch
@@ -37,13 +45,14 @@
--dataset_type chat
--train_data_uri ./sample_data/chatalpaca-openai-1k.jsonl
--val_data_uri None
---val_set_size 0.1
+--val_set_size 0.2
--sequence_len {sequence_len}
--long_sequences_strategy drop
--micro_batch_size 1
--eval_batch_size 1
+--eval_sample_packing True
--num_epochs 1
---max_steps 10
+--max_steps 3
--gradient_accumulation_steps 4
--gradient_checkpointing unsloth
--learning_rate 0.00001
@@ -67,7 +76,7 @@
--truefoundry_ml_log_checkpoints False
--truefoundry_ml_log_gpu_metrics True
--truefoundry_ml_log_merged_model False
---truefoundry_testing_mode True
+--merge_adapters_post_train False
"""


@@ -77,7 +86,7 @@ def stream_output(pipe, prefix=""):
pipe.close()


-def run_command(command: List[str]):
+def run_command(command: List[str], stream_stdout=False, stream_stderr=False):
print("Running command: ", " ".join(command))
try:
process = subprocess.Popen(
@@ -91,10 +100,11 @@ def run_command(command: List[str]):
shell=True,
)
with ThreadPoolExecutor() as executor:
-            futures = [
-                executor.submit(stream_output, process.stdout, "STDOUT: "),
-                # executor.submit(stream_output, process.stderr, "STDERR: "),
-            ]
+            futures = []
+            if stream_stdout:
+                futures.append(executor.submit(stream_output, process.stdout, "STDOUT: "))
+            if stream_stderr:
+                futures.append(executor.submit(stream_output, process.stderr, "STDERR: "))
process.wait()
for future in futures:
future.result()
@@ -107,14 +117,25 @@ def run_command(command: List[str]):


def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="reporting_config.yaml")
+    parser.add_argument("--output", type=str, default="report.csv")
+    args = parser.parse_args()
+    with open(args.config) as f:
+        config = yaml.safe_load(f)
+    config = ReportingConfig.model_validate(config)
    env = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,roundup_power2_divisions:16",
        "CUDA_VISIBLE_DEVICES": "0",
        "TORCH_PER_PROCESS_MEMORY_LIMIT": "0.98",
+        "GPU_CLEANUP_N_ITERS": "3",
+        "GPU_CLEANUP_INTERVAL_SECONDS": "3",
    }
    for k, v in env.items():
        os.environ[k] = v
-    for model, seq_len, lora_r in itertools.product(MODELS, SEQ_LENS, LORA_RS):
+
+    reports = []
+    for model, seq_len, lora_r in itertools.product(config.base_models, config.sequence_lens, config.lora_rs):
if os.path.exists("axolotl_truefoundry.plugin.log"):
os.remove("axolotl_truefoundry.plugin.log")
if os.path.exists("train.log"):
@@ -126,12 +147,14 @@ def main():
sequence_len=str(seq_len),
lora_r=str(lora_r),
lora_alpha=str(lora_r * 2),
-            ml_repo=ML_REPO,
+            ml_repo=config.ml_repo,
run_name=run_name,
)
try:
run_command(
shlex.split(command),
+                stream_stdout=config.stream_stdout,
+                stream_stderr=config.stream_stderr,
)
except Exception as e:
print(f"Failed to run command: {e}")
@@ -156,23 +179,30 @@ def main():
if "CUDA out of memory. Tried to allocate" in line:
cuda_oom = True
break
-        config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+        model_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+        report = {
+            "base_model": model,
+            "seq_len": seq_len,
+            "lora_r": lora_r,
+            "trainable_params": trainable_params,
+            "all_params": all_params,
+            "cuda_oom": cuda_oom,
+            "max_gpu_memory_allocated": max_gpu_memory_allocated,
+            "model_config": json.loads(model_config.to_json_string()),
+        }
print("=" * 80)
print(f"Config: {config}")
print(f"Model: {model}")
print(f"Seq Len: {seq_len}")
print(f"LoRA R: {lora_r}")
print(f"Trainable Params: {trainable_params}")
print(f"All Params: {all_params}")
print(f"CUDA OOM: {cuda_oom}")
print(f"GPU Memory Allocated: {max_gpu_memory_allocated}")
print(json.dumps(report))
print("=" * 80)
reports.append(report)
if not trainable_params or not all_params:
raise Exception("Failed to capture params")

if not cuda_oom and max_gpu_memory_allocated == -1:
raise Exception("Failed to capture GPU memory usage")

+    df = pd.DataFrame(reports)
+    df.to_csv(args.output, index=False)


if __name__ == "__main__":
main()
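
The sweep matrix now comes from a YAML file instead of module-level constants. A hedged sketch of a reporting_config.yaml equivalent to overriding a few of the defaults above (any subset of keys may be given; pydantic fills in the rest):

    import yaml

    from reporting import ReportingConfig  # assumes reporting.py is importable

    # Keys mirror the ReportingConfig fields above; values are illustrative.
    config = ReportingConfig.model_validate(yaml.safe_load("""
    ml_repo: llm-ft-reporting
    base_models:
      - Qwen/Qwen2.5-0.5B-Instruct
      - Qwen/Qwen2.5-7B-Instruct
    sequence_lens: [512, 2048]
    lora_rs: [32]
    stream_stdout: true
    """))
    print(config.sequence_lens)  # [512, 2048]
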
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,4 +7,4 @@ snowflake-connector-python[pandas]==3.12.3
torch==2.3.1+cu121
torchao==0.6.1+cu121
truefoundry>=0.4.8,<0.5.0
-unsloth @ git+https://github.com/unslothai/unsloth@38663b01f5dd0e610b12475bd95b144303cff539
+unsloth @ git+https://github.com/unslothai/unsloth@c2b185e7dbe04cdf2b95c681f42416dbe19d5f97
4 changes: 1 addition & 3 deletions train.py
@@ -244,14 +244,12 @@ def _train_with_truefoundry(config_base: Path = Path("examples/"), **kwargs):
barrier()
if is_main_process():
cfg = load_config_file(path=axolotl_config)
-        if cfg.truefoundry_testing_mode is True:
-            return
model_dir = cfg.output_dir
log_step = get_step_for_final_model(
output_dir=cfg.output_dir, load_best_model_at_end=cfg.load_best_model_at_end
)
cleanup_checkpoints(output_dir=cfg.output_dir)
-        if cfg.adapter in {"lora", "qlora"}:
+        if cfg.adapter in {"lora", "qlora"} and cfg.merge_adapters_post_train:
with temporarily_unset_distributed_envs():
axolotl_merge_lora_cli(config=axolotl_config, device_map="auto")
model_dir = os.path.join(model_dir, "merged")
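
With the testing-mode early return gone, the merge step is now gated only by the adapter type and the new merge_adapters_post_train flag. Roughly what that step amounts to, as a hedged PEFT-level sketch (train.py actually drives it through axolotl_merge_lora_cli; paths are illustrative):

    from peft import AutoPeftModelForCausalLM

    # Load the trained LoRA/QLoRA adapter checkpoint (hypothetical path).
    model = AutoPeftModelForCausalLM.from_pretrained("./outputs/checkpoint-final")
    merged = model.merge_and_unload()  # fold adapter deltas into the base weights
    merged.save_pretrained("./outputs/merged")  # the directory train.py goes on to log
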
5 changes: 4 additions & 1 deletion utils.py
@@ -33,7 +33,10 @@ def get_gpu_metrics():
return metrics


-def try_cleanup_gpus(n_iters=6, interval_seconds=10):
+def try_cleanup_gpus(
+    n_iters=int(os.getenv("GPU_CLEANUP_N_ITERS", 6)),
+    interval_seconds=int(os.getenv("GPU_CLEANUP_INTERVAL_SECONDS", 10)),
+):
for _ in range(n_iters):
gc.collect()
time.sleep(interval_seconds)
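
Worth noting: these defaults are evaluated once, at import time of utils.py, not per call, so the environment variables must be set before the module is imported. reporting.py does exactly that, exporting them in the parent process before spawning accelerate launch. A small sketch of the override path (assuming utils.py is on the import path):

    import os

    # Must happen before `import utils` for the new defaults to take effect.
    os.environ["GPU_CLEANUP_N_ITERS"] = "3"
    os.environ["GPU_CLEANUP_INTERVAL_SECONDS"] = "3"

    from utils import try_cleanup_gpus

    try_cleanup_gpus()  # now runs 3 cleanup iterations, 3 seconds apart
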
