feat: Added QLoRA fine-tuning pipeline (#40)
carlosgjs authored Feb 14, 2024
1 parent 4c79762 commit aac9880
Showing 11 changed files with 242 additions and 38 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -148,3 +148,6 @@ _html/

# mlflow output
mlruns/

# tensorboard output
results/
12 changes: 12 additions & 0 deletions .mypy.ini
@@ -14,3 +14,15 @@ ignore_missing_imports = True

[mypy-sentence_transformers.*]
ignore_missing_imports = True

[mypy-datasets.*]
ignore_missing_imports = True

[mypy-trl.*]
ignore_missing_imports = True

[mypy-peft.*]
implicit_reexport = True

[mypy-numba.*]
ignore_missing_imports = True
33 changes: 29 additions & 4 deletions README.md
@@ -48,6 +48,22 @@ Notes:
the Python Project Template documentation on
[Sphinx and Python Notebooks](https://lincc-ppt.readthedocs.io/en/latest/practices/sphinx.html#python-notebooks)

## Models

The models are hosted in the [autora-doc](https://huggingface.co/autora-doc) Hugging Face organization.

## Usage

Once the package is installed, documentation can be generated through the `autodoc` CLI tool:

```sh
autodoc generate <autora python file>
```
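
For example, to document a hypothetical experiment script and write the result to a custom file (file names are illustrative; `--output` is the option defined on the `generate` command):

```sh
autodoc generate my_experiment.py --output my_experiment_doc.txt
```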

## Running on Colab

A notebook for testing different prompts can be run on Google Colab through [this link](https://colab.research.google.com/github/AutoResearch/autodoc/blob/main/notebooks/generate.ipynb). Be sure to change the Runtime type to a T4 GPU.


## Running AzureML pipelines

@@ -73,16 +89,25 @@ az configure --defaults workspace=<aml workspace> group=<resource group> locatio

### Running jobs

Prediction
Inference
```sh
az ml job create -f azureml/eval.yml --set display_name="Test prediction job" --set environment_variables.HF_TOKEN=<your huggingface token> --web
az ml job create -f azureml/generate.yml --set display_name="Test inference job"
```

Notes:
Evaluation
```sh
az ml job create -f azureml/eval.yml --set display_name="Test evaluation job"
```

Fine-Tuning (training)
```sh
az ml job create -f azureml/train.yml --set display_name="Test training job"
```

Additional arguments:
- `--name` sets the MLflow run ID
- `--display_name` sets the name shown in the experiment dashboard
- `--web` opens a browser window for tracking the job
- The `HF_TOKEN` environment variable is required for gated repos, which require authentication (see the combined example below)
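
A combined example that uses these options for the training job (run name and display name are illustrative):

```sh
az ml job create -f azureml/train.yml \
  --name qlora-ft-run-1 \
  --set display_name="QLoRA fine-tuning test" \
  --set environment_variables.HF_TOKEN=<your huggingface token> \
  --web
```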


### Uploading data
4 changes: 4 additions & 0 deletions azureml/conda.yml
@@ -18,3 +18,7 @@ dependencies:
# This works, while installing pytorch and cuda from conda does not
- torch==2.0.1
- sentence_transformers>=2.3.1
- datasets
- peft>=0.8.2
- trl>=0.7.10
- tensorboardX
30 changes: 30 additions & 0 deletions azureml/train.yml
@@ -0,0 +1,30 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
python -m autora.doc.pipelines.main train
${{inputs.new_model_name}}
${{inputs.data_dir}}/data.jsonl
--base-model ${{inputs.model_path}}
code: ../src
inputs:
data_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/data/autora
model_path: autora-doc/Llama-2-7b-chat-hf-nf4
new_model_name: autora-doc/Llama-2-7b-chat-hf-nf4-ft
environment_variables:
PYTORCH_CUDA_ALLOC_CONF: max_split_size_mb:128
# using a curated environment doesn't work because we need additional packages
environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
# These didn't work
# image: mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu117-py38-torch201:biweekly.202310.3
# image: mcr.microsoft.com/azureml/curated/acpt-pytorch-1.13-cuda11.7:latest
# image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.0.3-cudnn8-ubuntu18.04
# image: mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04
# image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04
# image: nvcr.io/nvidia/pytorch:23.10-py3
conda_file: conda.yml
display_name: autodoc_train
compute: azureml:v100cluster
experiment_name: train
description: |
11 changes: 10 additions & 1 deletion pyproject.toml
@@ -43,7 +43,16 @@ dev = [
"ipykernel",
"hf_transfer",
]
pipelines = ["jsonlines", "mlflow", "nltk", "sentence-transformers>=2.3.1"]
pipelines = [
"jsonlines",
"mlflow",
"nltk",
"sentence-transformers>=2.3.1",
"peft>=0.8.2",
"trl>=0.7.10",
"datasets",
"tensorboardX",
]
# NOTE: When updating dependencies, in particular cuda/azure ml, make sure to update azureml/conda.yml too
azure = ["azureml-core", "azureml-mlflow"]
cuda = ["bitsandbytes>=0.42.0", "accelerate>=0.24.1", "xformers"]
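For local development, the new training dependencies come in through the `pipelines` extra; a possible editable install from the repository root (the `cuda` extra is only needed on GPU machines):

```sh
pip install -e ".[pipelines]"        # CPU-only development
pip install -e ".[pipelines,cuda]"   # on a CUDA-capable machine
```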
21 changes: 21 additions & 0 deletions src/autora/doc/pipelines/data.py
@@ -0,0 +1,21 @@
from typing import Iterable, List, Tuple


def load_data(data_file: str) -> Tuple[List[str], List[str]]:
import jsonlines

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [item["instruction"] for item in items]
labels = [item["output"] for item in items]
return inputs, labels


def preprocess_code(code: str) -> str:
lines: Iterable[str] = code.splitlines()
skip_starts = {"import", "from", "#"}
lines = filter(
lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
lines,
)
return "\n".join(lines)
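
For reference, a minimal usage sketch of these helpers, assuming a local `data.jsonl` file with one `{"instruction": ..., "output": ...}` record per line:

```python
from autora.doc.pipelines.data import load_data, preprocess_code

# Hypothetical local file with "instruction"/"output" fields on each line
inputs, labels = load_data("data.jsonl")
print(f"Loaded {len(inputs)} examples")

# Strip imports, comments, and blank lines before the code is embedded in a prompt
print(preprocess_code(inputs[0]))
```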
37 changes: 23 additions & 14 deletions src/autora/doc/pipelines/main.py
@@ -1,17 +1,24 @@
import itertools
import logging
from timeit import default_timer as timer
from typing import Dict, List, Tuple
from typing import Dict, List

import torch
import typer

from autora.doc.classes.EvalResult import EvalResult
from autora.doc.pipelines.data import load_data
from autora.doc.pipelines.metrics import eval_bleu_meteor, eval_semscore
from autora.doc.pipelines.train import fine_tune, get_dataset
from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import PROMPTS, PromptIds
from autora.doc.util import get_prompts_from_file

# For inference
DEFAULT_INFERENCE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
# For training
DEFAULT_BASE_MODEL = "autora-doc/Llama-2-7b-chat-hf"

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
@@ -24,7 +31,7 @@
@app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
def eval_prompts(
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
prompts_file: str = typer.Argument(..., help="JSON file with a list of dictionary of prompts"),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
@@ -62,7 +69,7 @@ def eval_prompts(
@app.command(help="Evaluate model on a data file")
def eval(
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
@@ -89,16 +96,6 @@ def eval(
return eval_prompt(data_file, pred, prompt, param_dict)


def load_data(data_file: str) -> Tuple[List[str], List[str]]:
import jsonlines

with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [f"{item['instruction']}" for item in items]
labels = [item["output"] for item in items]
return inputs, labels


def eval_prompt(data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float]) -> EvalResult:
import mlflow

@@ -132,7 +129,7 @@ def eval_prompt(data_file: str, pred: Predictor, prompt: str, param_dict: Dict[s
@app.command()
def generate(
python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
output: str = typer.Option("output.txt", help="Output file"),
prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
param: List[str] = typer.Option(
@@ -160,6 +157,18 @@ def import_model(model_name: str) -> None:
pass


@app.command()
def train(
new_model_name: str = typer.Argument(..., help="File name for the fine-tuned model"),
dataset: str = typer.Argument(..., help="Path to the jsonl file with training data"),
base_model: str = typer.Option(
DEFAULT_BASE_MODEL, help="Path to the base Huggingface model to fine-tune"
),
) -> None:
ds = get_dataset(dataset)
fine_tune(base_model, new_model_name, ds)


@app.command()
def import_data(code_file: str, text_file: str, output_file: str = "data.jsonl") -> None:
from pathlib import Path
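With the new `train` command, fine-tuning can also be launched directly from the CLI; a sketch of the invocation, using the model names from `azureml/train.yml` and an illustrative local data path:

```sh
python -m autora.doc.pipelines.main train \
  autora-doc/Llama-2-7b-chat-hf-nf4-ft \
  data/data.jsonl \
  --base-model autora-doc/Llama-2-7b-chat-hf-nf4
```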
89 changes: 89 additions & 0 deletions src/autora/doc/pipelines/train.py
@@ -0,0 +1,89 @@
from typing import Dict, Iterable

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

from autora.doc.pipelines.data import load_data, preprocess_code
from autora.doc.runtime.predict_hf import get_quantization_config
from autora.doc.runtime.prompts import INSTR_SWEETP_1, SYS_GUIDES, PromptBuilder


def get_dataset(jsonl_path: str) -> Dataset:
# "instruction", "output"
inputs, labels = load_data(jsonl_path)

def gen() -> Iterable[Dict[str, str]]:
for i, o in zip(inputs, labels):
text = PromptBuilder(SYS_GUIDES, INSTR_SWEETP_1).add_example(preprocess_code(i), o).build()
yield {"text": text}

ds = Dataset.from_generator(gen)
return ds


def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:
cuda_available = torch.cuda.is_available()
config = {}

# train using 4 bit quantization for lower GPU memory usage
if cuda_available:
config.update({"device_map": "auto", "quantization_config": get_quantization_config()})

model = AutoModelForCausalLM.from_pretrained(
base_model,
**config,
)
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

peft_params = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=8,
bias="none",
task_type="CAUSAL_LM",
)

# All of these parameters are initial defaults and may need further tuning
training_params = TrainingArguments(
output_dir="./results",
num_train_epochs=4,
per_device_train_batch_size=1, # TODO: Increase once there's more data
gradient_accumulation_steps=1,
optim="paged_adamw_32bit" if cuda_available else "adamw_torch",
save_steps=25,
logging_steps=1, # TODO: Increase once there's more data
learning_rate=2e-4,
weight_decay=0.001,
fp16=cuda_available,
bf16=False,
max_grad_norm=0.3,
max_steps=-1,
warmup_ratio=0.03,
group_by_length=True,
lr_scheduler_type="constant",
report_to="tensorboard",
)

# Use a Supervised Fine-Tuning Trainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_params,
dataset_text_field="text",
max_seq_length=1024,
tokenizer=tokenizer,
args=training_params,
packing=False,
)

trainer.train()
trainer.model.save_pretrained(new_model_name)
trainer.tokenizer.save_pretrained(new_model_name)
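
Because `SFTTrainer` is configured with a `peft_config`, `save_pretrained` writes only the LoRA adapter weights rather than a full model. A hedged sketch of how the saved adapter might later be attached to a base model for inference (model and adapter names are illustrative; the repo's `Predictor` may load models differently):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the (non-quantized) base model and tokenizer the adapter was trained against
base = AutoModelForCausalLM.from_pretrained("autora-doc/Llama-2-7b-chat-hf")
tokenizer = AutoTokenizer.from_pretrained("autora-doc/Llama-2-7b-chat-hf")

# Attach the saved LoRA adapter, then optionally fold it into the base weights
model = PeftModel.from_pretrained(base, "autora-doc/Llama-2-7b-chat-hf-nf4-ft")
model = model.merge_and_unload()
```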
32 changes: 14 additions & 18 deletions src/autora/doc/runtime/predict_hf.py
@@ -1,10 +1,11 @@
import logging
from typing import Dict, Iterable, List, Tuple
from typing import Any, Dict, List, Tuple

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from autora.doc.pipelines.data import preprocess_code
from autora.doc.runtime.prompts import CODE_PLACEHOLDER, LLAMA2_INST_CLOSE

logger = logging.getLogger(__name__)
@@ -13,16 +14,6 @@
non_quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf"}


def preprocess_code(code: str) -> str:
lines: Iterable[str] = code.splitlines()
skip_starts = {"import", "from", "#"}
lines = filter(
lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
lines,
)
return "\n".join(lines)


class Predictor:
def __init__(self, input_model_path: str):
model_path, config = Predictor.get_config(input_model_path)
@@ -93,7 +84,6 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
if torch.cuda.is_available():
logger.info("CUDA is available, attempting to load quantized model")
from transformers import BitsAndBytesConfig

config = {"device_map": "auto"}
mapped_path = quantized_models.get(model_path, None)
@@ -102,14 +92,20 @@ def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
return mapped_path, config

# Load the model in 4bit quantization for faster inference on smaller GPUs
config["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
config["quantization_config"] = get_quantization_config()
return model_path, config
else:
logger.info("CUDA is not available, loading non-quantized model")
mapped_path = non_quantized_models.get(model_path, model_path)
return mapped_path, {}


def get_quantization_config() -> Any:
from transformers import BitsAndBytesConfig

return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)