diff --git a/README.md b/README.md
index 6db068f..2ff8a62 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,9 @@
 [![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/)
 [![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/)
-[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/AutoResearch/autodoc/smoke-test.yml)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
+
+
+[![GitHub Workflow Status](https://github.com/autoresearch/autodoc/actions/workflows/smoke-test.yml/badge.svg)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
 [![codecov](https://codecov.io/gh/AutoResearch/autodoc/branch/main/graph/badge.svg)](https://codecov.io/gh/AutoResearch/autodoc)
 [![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/)
diff --git a/azureml/predict.yml b/azureml/predict.yml
index 7f888b4..d5410a2 100644
--- a/azureml/predict.yml
+++ b/azureml/predict.yml
@@ -1,5 +1,10 @@
 $schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
-command: python -m autora.doc.pipelines.main predict ${{inputs.data_dir}}/data.jsonl ${{inputs.model_dir}}/llama-2-7b-chat-hf
+command: >
+  python -m autora.doc.pipelines.main predict
+  ${{inputs.data_dir}}/data.jsonl
+  ${{inputs.model_dir}}/llama-2-7b-chat-hf
+  SYS_1
+  INSTR_SWEETP_1
 code: ../src
 inputs:
   data_dir:
diff --git a/pyproject.toml b/pyproject.toml
index e6db6a3..97c9c31 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,3 +98,6 @@ include = ["src/autora"]
 
 [tool.hatch.build.targets.wheel]
 packages = ["src/autora"]
+
+[project.scripts]
+autodoc = "autora.doc.pipelines.main:app"
\ No newline at end of file
diff --git a/src/autora/doc/example_module.py b/src/autora/doc/example_module.py
deleted file mode 100644
index f76e837..0000000
--- a/src/autora/doc/example_module.py
+++ /dev/null
@@ -1,23 +0,0 @@
-"""An example module containing simplistic functions."""
-
-
-def greetings() -> str:
-    """A friendly greeting for a future friend.
-
-    Returns
-    -------
-    str
-        A typical greeting from a software engineer.
-    """
-    return "Hello from LINCC-Frameworks!"
-
-
-def meaning() -> int:
-    """The meaning of life, the universe, and everything.
-
-    Returns
-    -------
-    int
-        The meaning of life.
-    """
-    return 42
diff --git a/src/autora/doc/pipelines/main.py b/src/autora/doc/pipelines/main.py
index 292c8ff..b74bf4b 100644
--- a/src/autora/doc/pipelines/main.py
+++ b/src/autora/doc/pipelines/main.py
@@ -7,6 +7,7 @@ import typer
 
 from autora.doc.runtime.predict_hf import Predictor
+from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts
 
 app = typer.Typer()
 logging.basicConfig(
@@ -15,21 +16,13 @@
 )
 logger = logging.getLogger(__name__)
 
-# TODO: organize the system and instruction prompts into a separate module
-SYS = """You are a technical documentation writer. You always write clear, concise, and accurate documentation for
-    scientific experiments. Your documentation focuses on the experiment's purpose, procedure, and results. Therefore,
-    details about specific python functions, packages, or libraries are not necessary. Your readers are experimental
-    scientists.
-"""
-
-instr = """Please generate high-level two paragraph documentation for the following experiment. The first paragraph
-should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""
-
 
 @app.command()
-def predict(data_file: str, model_path: str) -> None:
+def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> None:
     run = mlflow.active_run()
+    sys_prompt = SYS[sys_id]
+    instr_prompt = INSTR[instruc_id]
     if run is None:
         run = mlflow.start_run()
     with run:
@@ -45,7 +38,7 @@ def predict(data_file: str, model_path: str) -> None:
 
         pred = Predictor(model_path)
         timer_start = timer()
-        predictions = pred.predict(SYS, instr, inputs)
+        predictions = pred.predict(sys_prompt, instr_prompt, inputs)
         timer_end = timer()
         pred_time = timer_end - timer_start
         mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
diff --git a/src/autora/doc/runtime/predict_hf.py b/src/autora/doc/runtime/predict_hf.py
index ba3e59d..cbde760 100644
--- a/src/autora/doc/runtime/predict_hf.py
+++ b/src/autora/doc/runtime/predict_hf.py
@@ -3,24 +3,22 @@
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE, TEMP_LLAMA2
 
 logger = logging.getLogger(__name__)
 
 
 class Predictor:
     def __init__(self, model_path: str):
-        # Load the model in 4bit quantization for faster inference on smaller GPUs
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.bfloat16,
-        )
+        config = self.get_config()
+
         logger.info(f"Loading model from {model_path}")
         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path, quantization_config=bnb_config, device_map="auto"
+            model_path,
+            **config,
         )
         logger.info("Model loaded")
         self.pipeline = transformers.pipeline(
@@ -30,18 +28,8 @@ def __init__(self, model_path: str):
         )
 
     def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
-        # Standard Llama2 template
-        template = f"""
-[INST]<<SYS>>
-{sys}
-
-{instr}
-
-[INPUT]
-[/INST]
-"""
         logger.info(f"Generating {len(inputs)} predictions")
-        prompts = [template.replace("[INPUT]", input) for input in inputs]
+        prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
         # TODO: Make these parameters configurable
         sequences = self.pipeline(
            prompts,
@@ -54,10 +42,35 @@
             max_length=1000,
         )
 
-        results = [sequence[0]["generated_text"] for sequence in sequences]
+        results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]
         logger.info(f"Generated {len(results)} results")
         return results
 
+    @staticmethod
+    def trim_prompt(output: str) -> str:
+        marker = output.find(LLAMA2_INST_CLOSE)
+        if marker == -1:
+            logger.warning(f"Could not find end of prompt marker '{LLAMA2_INST_CLOSE}' in '{output}'")
+            return output
+        return output[marker + len(LLAMA2_INST_CLOSE) :]
+
     def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
         tokens: Dict[str, List[List[int]]] = self.tokenizer(input)
         return tokens
+
+    def get_config(self) -> Dict[str, str]:
+        if torch.cuda.is_available():
+            from transformers import BitsAndBytesConfig
+
+            # Load the model in 4bit quantization for faster inference on smaller GPUs
+            return {
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                ),
+                "device_map": "auto",
+            }
+        else:
+            return {}
diff --git a/src/autora/doc/runtime/prompts.py b/src/autora/doc/runtime/prompts.py
new file mode 100644
index 0000000..5875127
--- /dev/null
+++ b/src/autora/doc/runtime/prompts.py
@@ -0,0 +1,37 @@
+from enum import Enum
+
+LLAMA2_INST_CLOSE = "[/INST]\n"
+
+# Standard Llama2 template
+TEMP_LLAMA2 = """
+[INST]<<SYS>>
+{sys}
+<</SYS>>
+
+{instr}
+
+{input}
+[/INST]
+"""
+
+
+SYS_1 = """You are a technical documentation writer. You always write clear, concise, and accurate documentation for
+scientific experiments. Your documentation focuses on the experiment's purpose, procedure, and results. Therefore,
+details about specific python functions, packages, or libraries are not necessary. Your readers are experimental
+scientists.
+"""
+
+INSTR_SWEETP_1 = """Please generate high-level two paragraph documentation for the following experiment. The first
+paragraph should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""
+
+
+class SystemPrompts(Enum):
+    SYS_1 = "SYS_1"
+
+
+class InstructionPrompts(Enum):
+    INSTR_SWEETP_1 = "INSTR_SWEETP_1"
+
+
+SYS = {SystemPrompts.SYS_1: SYS_1}
+INSTR = {InstructionPrompts.INSTR_SWEETP_1: INSTR_SWEETP_1}
diff --git a/tests/test.py b/tests/test.py
index a578227..9ad2c47 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -1,13 +1,16 @@
-from autora.doc import example_module
+from autora.doc.runtime.predict_hf import Predictor
 
 
-def test_greetings() -> None:
-    """Verify the output of the `greetings` function"""
-    output = example_module.greetings()
-    assert output == "Hello from LINCC-Frameworks!"
+def test_trim_prompt() -> None:
+    """Verify the output of the `trim_prompt` function"""
+    no_marker = "Generated text with no marker"
+    output = Predictor.trim_prompt(no_marker)
+    assert output == no_marker
 
-
-def test_meaning() -> None:
-    """Verify the output of the `meaning` function"""
-    output = example_module.meaning()
-    assert output == 42
+    with_marker = """
+The prompt is here
+[/INST]
+output
+"""
+    output = Predictor.trim_prompt(with_marker)
+    assert output == "output\n"
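
Reviewer note: the new prompt plumbing can be sanity-checked without a GPU or a model download. The sketch below is not part of the diff; it assumes the package from this patch is installed (e.g. `pip install -e .`), and the `"x = 1"` input is a hypothetical stand-in for real experiment code. It resolves prompts through the new enum registry and renders one full prompt the same way `Predictor.predict` now does.

```python
# Minimal sketch: look up prompts by id and render the Llama 2 template,
# mirroring what the refactored `predict` command does before loading a model.
from autora.doc.runtime.prompts import (
    INSTR,
    SYS,
    TEMP_LLAMA2,
    InstructionPrompts,
    SystemPrompts,
)

# Resolve the prompt text from the enum ids, as the CLI arguments now do.
sys_prompt = SYS[SystemPrompts.SYS_1]
instr_prompt = INSTR[InstructionPrompts.INSTR_SWEETP_1]

# Render one complete prompt; "x = 1" is a placeholder for experiment code.
prompt = TEMP_LLAMA2.format(sys=sys_prompt, instr=instr_prompt, input="x = 1")
print(prompt)
```

The end-to-end equivalent through the new `[project.scripts]` entry point should be `autodoc predict <data.jsonl> <model-dir> SYS_1 INSTR_SWEETP_1`, matching the argument order added to `azureml/predict.yml`.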