diff --git a/.gitignore b/.gitignore
index cce3a02..1caceb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -148,3 +148,6 @@ _html/
 
 # mlflow output
 mlruns/
+
+#tensorflow output
+results/
diff --git a/.mypy.ini b/.mypy.ini
index 9bf3cdf..e260c8b 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -14,3 +14,15 @@ ignore_missing_imports = True
 
 [mypy-sentence_transformers.*]
 ignore_missing_imports = True
+
+[mypy-datasets.*]
+ignore_missing_imports = True
+
+[mypy-trl.*]
+ignore_missing_imports = True
+
+[mypy-peft.*]
+implicit_reexport = True
+
+[mypy-numba.*]
+ignore_missing_imports = True
diff --git a/README.md b/README.md
index de851a7..8f0ab11 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,22 @@ Notes:
   the Python Project Template documentation on
   [Sphinx and Python Notebooks](https://lincc-ppt.readthedocs.io/en/latest/practices/sphinx.html#python-notebooks)
 
+## Models
+
+The models are hosted in the [autora-doc](https://huggingface.co/autora-doc) Huggingface organization.
+
+## Usage
+
+Once the package is installed, documentation can be generated through the `autodoc` CLI tool:
+
+```sh
+autodoc generate
+```
+
+## Running on Colab
+
+A notebook for testing different prompts can be run on Google Colab through [this link](https://colab.research.google.com/github/AutoResearch/autodoc/blob/main/notebooks/generate.ipynb). Be sure to change the Runtime type to a T4 GPU.
+
 ## Running AzureML pipelines
 
 
@@ -73,16 +89,25 @@ az configure --defaults workspace= group= locatio
 
 ### Running jobs
 
-Prediction
+Inference
 ```sh
-az ml job create -f azureml/eval.yml --set display_name="Test prediction job" --set environment_variables.HF_TOKEN= --web
+az ml job create -f azureml/generate.yml --set display_name="Test inference job"
 ```
 
-Notes:
+Evaluation
+```sh
+az ml job create -f azureml/eval.yml --set display_name="Test evaluation job"
+```
+
+Fine-Tuning (training)
+```sh
+az ml job create -f azureml/train.yml --set display_name="Test training job"
+```
+
+Additional arguments:
 - `--name` will set the mlflow run id
 - `--display_name` becomes the name in the experiment dashboard
 - `--web` argument will pop-up a browser window for tracking the job.
-- The `HF_TOKEN` is required for gated repos, which need authentication
 
 ### Uploading data
 
diff --git a/azureml/conda.yml b/azureml/conda.yml
index 32ed227..bbaf766 100644
--- a/azureml/conda.yml
+++ b/azureml/conda.yml
@@ -18,3 +18,7 @@ dependencies:
   # This works, while installing from pytorch and cuda from conda does not
   - torch==2.0.1
   - sentence_transformers>=2.3.1
+  - datasets
+  - peft>=0.8.2
+  - trl>=0.7.10
+  - tensorboardX
diff --git a/azureml/train.yml b/azureml/train.yml
new file mode 100644
index 0000000..4db974b
--- /dev/null
+++ b/azureml/train.yml
@@ -0,0 +1,30 @@
+$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
+command: >
+  python -m autora.doc.pipelines.main train
+  ${{inputs.new_model_name}}
+  ${{inputs.data_dir}}/data.jsonl
+  --base-model ${{inputs.model_path}}
+code: ../src
+inputs:
+  data_dir:
+    type: uri_folder
+    path: azureml://datastores/workspaceblobstore/paths/data/autora
+  model_path: autora-doc/Llama-2-7b-chat-hf-nf4
+  new_model_name: autora-doc/Llama-2-7b-chat-hf-nf4-ft
+environment_variables:
+  PYTORCH_CUDA_ALLOC_CONF: max_split_size_mb:128
+# using a curated environment doesn't work because we need additional packages
+environment: # azureml://registries/azureml/environments/acpt-pytorch-2.0-cuda11.7/versions/21
+  image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
+  # These didn't work
+  # image: mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu117-py38-torch201:biweekly.202310.3
+  # image: mcr.microsoft.com/azureml/curated/acpt-pytorch-1.13-cuda11.7:latest
+  # image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.0.3-cudnn8-ubuntu18.04
+  # image: mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04
+  # image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04
+  # image: nvcr.io/nvidia/pytorch:23.10-py3
+  conda_file: conda.yml
+display_name: autodoc_train
+compute: azureml:v100cluster
+experiment_name: train
+description: |
diff --git a/pyproject.toml b/pyproject.toml
index d8a335e..b21d605 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,16 @@ dev = [
     "ipykernel",
     "hf_transfer",
 ]
-pipelines = ["jsonlines", "mlflow", "nltk", "sentence-transformers>=2.3.1"]
+pipelines = [
+    "jsonlines",
+    "mlflow",
+    "nltk",
+    "sentence-transformers>=2.3.1",
+    "peft>=0.8.2",
+    "trl>=0.7.10",
+    "datasets",
+    "tensorboardX",
+]
 # NOTE: When updating dependencies, in particular cuda/azure ml, make sure to update the azureml/conda.yaml too
 azure = ["azureml-core", "azureml-mlflow"]
 cuda = ["bitsandbytes>=0.42.0", "accelerate>=0.24.1", "xformers"]
diff --git a/src/autora/doc/pipelines/data.py b/src/autora/doc/pipelines/data.py
new file mode 100644
index 0000000..e067297
--- /dev/null
+++ b/src/autora/doc/pipelines/data.py
@@ -0,0 +1,21 @@
+from typing import Iterable, List, Tuple
+
+
+def load_data(data_file: str) -> Tuple[List[str], List[str]]:
+    import jsonlines
+
+    with jsonlines.open(data_file) as reader:
+        items = [item for item in reader]
+        inputs = [item["instruction"] for item in items]
+        labels = [item["output"] for item in items]
+    return inputs, labels
+
+
+def preprocess_code(code: str) -> str:
+    lines: Iterable[str] = code.splitlines()
+    skip_starts = {"import", "from", "#"}
+    lines = filter(
+        lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
+        lines,
+    )
+    return "\n".join(lines)
diff --git a/src/autora/doc/pipelines/main.py b/src/autora/doc/pipelines/main.py
index 5800ec2..56a066c 100644
--- a/src/autora/doc/pipelines/main.py
+++ b/src/autora/doc/pipelines/main.py
@@ -1,17 +1,24 @@
 import itertools
 import logging
 from timeit import default_timer as timer
-from typing import Dict, List, Tuple
+from typing import Dict, List
 
 import torch
 import typer
 
 from autora.doc.classes.EvalResult import EvalResult
+from autora.doc.pipelines.data import load_data
 from autora.doc.pipelines.metrics import eval_bleu_meteor, eval_semscore
+from autora.doc.pipelines.train import fine_tune, get_dataset
 from autora.doc.runtime.predict_hf import Predictor
 from autora.doc.runtime.prompts import PROMPTS, PromptIds
 from autora.doc.util import get_prompts_from_file
 
+# For inference
+DEFAULT_INFERENCE_MODEL = "meta-llama/Llama-2-7b-chat-hf"
+# For training
+DEFAULT_BASE_MODEL = "autora-doc/Llama-2-7b-chat-hf"
+
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
@@ -24,7 +31,7 @@
 @app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
 def eval_prompts(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
     prompts_file: str = typer.Argument(..., help="JSON file with a list of dictionary of prompts"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
@@ -62,7 +69,7 @@ def eval_prompts(
 @app.command(help="Evaluate model on a data file")
 def eval(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
     prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
@@ -89,16 +96,6 @@ def eval(
     return eval_prompt(data_file, pred, prompt, param_dict)
 
 
-def load_data(data_file: str) -> Tuple[List[str], List[str]]:
-    import jsonlines
-
-    with jsonlines.open(data_file) as reader:
-        items = [item for item in reader]
-        inputs = [f"{item['instruction']}" for item in items]
-        labels = [item["output"] for item in items]
-    return inputs, labels
-
-
 def eval_prompt(data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float]) -> EvalResult:
     import mlflow
 
@@ -132,7 +129,7 @@ def eval_prompt(data_file: str, pred: Predictor, prompt: str, param_dict: Dict[s
 @app.command()
 def generate(
     python_file: str = typer.Argument(..., help="Python file to generate documentation for"),
-    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    model_path: str = typer.Option(DEFAULT_INFERENCE_MODEL, help="Path to HF model"),
     output: str = typer.Option("output.txt", help="Output file"),
     prompt_id: PromptIds = typer.Option(PromptIds.SWEETP_1, help="Instruction prompt ID"),
     param: List[str] = typer.Option(
@@ -160,6 +157,18 @@ def import_model(model_name: str) -> None:
     pass
 
 
+@app.command()
+def train(
+    new_model_name: str = typer.Argument(..., help="File name for the fine-tuned model"),
+    dataset: str = typer.Argument(..., help="Path to the jsonl file with training data"),
+    base_model: str = typer.Option(
+        DEFAULT_BASE_MODEL, help="Path to the base Huggingface model to fine-tune"
+    ),
+) -> None:
+    ds = get_dataset(dataset)
+    fine_tune(base_model, new_model_name, ds)
+
+
 @app.command()
 def import_data(code_file: str, text_file: str, output_file: str = "data.jsonl") -> None:
     from pathlib import Path
diff --git a/src/autora/doc/pipelines/train.py b/src/autora/doc/pipelines/train.py
new file mode 100644
index 0000000..b001fa3
--- /dev/null
+++ b/src/autora/doc/pipelines/train.py
@@ -0,0 +1,89 @@
+from typing import Dict, Iterable
+
+import torch
+from datasets import Dataset
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
+from trl import SFTTrainer
+
+from autora.doc.pipelines.data import load_data, preprocess_code
+from autora.doc.runtime.predict_hf import get_quantization_config
+from autora.doc.runtime.prompts import INSTR_SWEETP_1, SYS_GUIDES, PromptBuilder
+
+
+def get_dataset(jsonl_path: str) -> Dataset:
+    # "instruction", "output"
+    inputs, labels = load_data(jsonl_path)
+
+    def gen() -> Iterable[Dict[str, str]]:
+        for i, o in zip(inputs, labels):
+            text = PromptBuilder(SYS_GUIDES, INSTR_SWEETP_1).add_example(preprocess_code(i), o).build()
+            yield {"text": text}
+
+    ds = Dataset.from_generator(gen)
+    return ds
+
+
+def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None:
+    cuda_available = torch.cuda.is_available()
+    config = {}
+
+    # train using 4 bit quantization for lower GPU memory usage
+    if cuda_available:
+        config.update({"device_map": "auto", "quantization_config": get_quantization_config()})
+
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        **config,
+    )
+    model.config.use_cache = False
+    model.config.pretraining_tp = 1
+
+    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "right"
+
+    peft_params = LoraConfig(
+        lora_alpha=16,
+        lora_dropout=0.05,
+        r=8,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+
+    # All of these parameters are initial defaults and may need further tuning
+    training_params = TrainingArguments(
+        output_dir="./results",
+        num_train_epochs=4,
+        per_device_train_batch_size=1,  # TODO: Increase once there's more data
+        gradient_accumulation_steps=1,
+        optim="paged_adamw_32bit" if cuda_available else "adamw_torch",
+        save_steps=25,
+        logging_steps=1,  # TODO: Increase once there's more data
+        learning_rate=2e-4,
+        weight_decay=0.001,
+        fp16=cuda_available,
+        bf16=False,
+        max_grad_norm=0.3,
+        max_steps=-1,
+        warmup_ratio=0.03,
+        group_by_length=True,
+        lr_scheduler_type="constant",
+        report_to="tensorboard",
+    )
+
+    # Use a Supervised Fine-Tuning Trainer
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=dataset,
+        peft_config=peft_params,
+        dataset_text_field="text",
+        max_seq_length=1024,
+        tokenizer=tokenizer,
+        args=training_params,
+        packing=False,
+    )
+
+    trainer.train()
+    trainer.model.save_pretrained(new_model_name)
+    trainer.tokenizer.save_pretrained(new_model_name)
diff --git a/src/autora/doc/runtime/predict_hf.py b/src/autora/doc/runtime/predict_hf.py
index e1be754..d8bf424 100644
--- a/src/autora/doc/runtime/predict_hf.py
+++ b/src/autora/doc/runtime/predict_hf.py
@@ -1,10 +1,11 @@
 import logging
-from typing import Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
 import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from autora.doc.pipelines.data import preprocess_code
 from autora.doc.runtime.prompts import CODE_PLACEHOLDER, LLAMA2_INST_CLOSE
 
 logger = logging.getLogger(__name__)
@@ -13,16 +14,6 @@
 non_quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf"}
 
 
-def preprocess_code(code: str) -> str:
-    lines: Iterable[str] = code.splitlines()
-    skip_starts = {"import", "from", "#"}
-    lines = filter(
-        lambda line: not (any([line.strip().startswith(skip) for skip in skip_starts]) or line.strip() == ""),
-        lines,
-    )
-    return "\n".join(lines)
-
-
 class Predictor:
     def __init__(self, input_model_path: str):
         model_path, config = Predictor.get_config(input_model_path)
@@ -93,7 +84,6 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
     def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
         if torch.cuda.is_available():
             logger.info("CUDA is available, attempting to load quantized model")
-            from transformers import BitsAndBytesConfig
 
             config = {"device_map": "auto"}
             mapped_path = quantized_models.get(model_path, None)
@@ -102,14 +92,20 @@ def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
                 return mapped_path, config
 
             # Load the model in 4bit quantization for faster inference on smaller GPUs
-            config["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            )
+            config["quantization_config"] = get_quantization_config()
             return model_path, config
         else:
             logger.info("CUDA is not available, loading non-quantized model")
             mapped_path = non_quantized_models.get(model_path, model_path)
             return mapped_path, {}
+
+
+def get_quantization_config() -> Any:
+    from transformers import BitsAndBytesConfig
+
+    return BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.float16,
+    )
diff --git a/tests/test_main.py b/tests/test_main.py
index 2e9ba66..9ce4893 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -2,7 +2,7 @@
 from typing import List
 
 from autora.doc.classes.EvalResult import EvalResult
-from autora.doc.pipelines.main import eval, eval_prompts, generate, import_data
+from autora.doc.pipelines.main import eval, eval_prompts, generate, import_data, train
 from autora.doc.runtime.prompts import PromptIds
 
 # dummy HF model for testing
@@ -44,3 +44,9 @@ def test_eval_prompts() -> None:
     for result in results:
         assert result.predictions is not None, "The prediction should not be None"
         assert result.prompt is not None, "The prompt should not be None"
+
+
+def test_train(tmp_path: Path) -> None:
+    dataset = Path(__file__).parent.joinpath("../data/autora/data.jsonl").resolve()
+    train(str(tmp_path), str(dataset), TEST_HF_MODEL)
+    assert tmp_path.joinpath("./adapter_model.safetensors").exists(), "Expected model to be trained"
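
Besides the AzureML job in `azureml/train.yml`, the new training pipeline can also be exercised directly from Python, which is how `tests/test_main.py` drives it. A minimal sketch follows; the data path and output name are illustrative, not part of the patch:

```python
from autora.doc.pipelines.train import fine_tune, get_dataset

# Build a text dataset from a JSONL file with "instruction"/"output" fields
# (path shown here is illustrative; the tests use data/autora/data.jsonl).
ds = get_dataset("data/autora/data.jsonl")

# LoRA/SFT fine-tune the base model; the adapter weights and tokenizer are
# saved under the given output name (e.g. adapter_model.safetensors).
fine_tune("autora-doc/Llama-2-7b-chat-hf", "Llama-2-7b-chat-hf-ft", ds)
```

The CLI equivalent is the command wired into `train.yml`: `python -m autora.doc.pipelines.main train <new_model_name> <data.jsonl> --base-model <model>`.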