Commit

refactor: Refactor prompts
carlosgjs committed Dec 5, 2023
1 parent 600b53f commit fcc744d
Showing 8 changed files with 101 additions and 68 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -3,7 +3,9 @@
[![Template](https://img.shields.io/badge/Template-LINCC%20Frameworks%20Python%20Project%20Template-brightgreen)](https://lincc-ppt.readthedocs.io/en/latest/)

[![PyPI](https://img.shields.io/pypi/v/autora-doc?color=blue&logo=pypi&logoColor=white)](https://pypi.org/project/autora-doc/)
-[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/AutoResearch/autodoc/smoke-test.yml)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
+
+
+[![GitHub Workflow Status](https://github.com/autoresearch/autodoc/actions/workflows/smoke-test.yml/badge.svg)](https://github.com/AutoResearch/autodoc/actions/workflows/smoke-test.yml)
[![codecov](https://codecov.io/gh/AutoResearch/autodoc/branch/main/graph/badge.svg)](https://codecov.io/gh/AutoResearch/autodoc)
[![Read the Docs](https://img.shields.io/readthedocs/autora-doc)](https://autora-doc.readthedocs.io/)

7 changes: 6 additions & 1 deletion azureml/predict.yml
@@ -1,5 +1,10 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
-command: python -m autora.doc.pipelines.main predict ${{inputs.data_dir}}/data.jsonl ${{inputs.model_dir}}/llama-2-7b-chat-hf
+command: >
+  python -m autora.doc.pipelines.main predict
+  ${{inputs.data_dir}}/data.jsonl
+  ${{inputs.model_dir}}/llama-2-7b-chat-hf
+  SYS_1
+  INSTR_SWEETP_1
code: ../src
inputs:
data_dir:
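The folded-block form (`command: >`) is only a layout change: YAML joins the indented lines back into the original single command string, now with the two new prompt IDs appended. A minimal sketch of that folding (assumes PyYAML is available; the literal paths stand in for the `${{inputs.*}}` values that AzureML substitutes):

```python
# Illustration only: the folded scalar collapses back into one command line.
import yaml  # assumes PyYAML is installed

doc = yaml.safe_load(
    """
command: >
  python -m autora.doc.pipelines.main predict
  data/data.jsonl
  models/llama-2-7b-chat-hf
  SYS_1
  INSTR_SWEETP_1
"""
)
print(doc["command"])
# python -m autora.doc.pipelines.main predict data/data.jsonl models/llama-2-7b-chat-hf SYS_1 INSTR_SWEETP_1
```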
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -98,3 +98,6 @@ include = ["src/autora"]

[tool.hatch.build.targets.wheel]
packages = ["src/autora"]

+[project.scripts]
+autodoc = "autora.doc.pipelines.main:app"
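The new `[project.scripts]` table makes pip generate an `autodoc` console script that imports `autora.doc.pipelines.main` and calls its `app` attribute. A minimal sketch of that mechanism (the `hello` command below is hypothetical, not part of the repository):

```python
# Sketch of what a "module:attr" console-script entry resolves to:
# the generated script imports the named attribute (here a Typer app) and calls it.
import typer

app = typer.Typer()


@app.command()
def hello(name: str = "autodoc") -> None:
    # Hypothetical command; the real commands live in autora.doc.pipelines.main.
    typer.echo(f"hello from {name}")


if __name__ == "__main__":
    app()  # running `autodoc hello` would invoke this same callable
```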
23 changes: 0 additions & 23 deletions src/autora/doc/example_module.py

This file was deleted.

17 changes: 5 additions & 12 deletions src/autora/doc/pipelines/main.py
@@ -7,6 +7,7 @@
import typer

from autora.doc.runtime.predict_hf import Predictor
+from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts

app = typer.Typer()
logging.basicConfig(
@@ -15,21 +16,13 @@
)
logger = logging.getLogger(__name__)

-# TODO: organize the system and instruction prompts into a separate module
-SYS = """You are a technical documentation writer. You always write clear, concise, and accurate documentation for
-scientific experiments. Your documentation focuses on the experiment's purpose, procedure, and results. Therefore,
-details about specific python functions, packages, or libraries are not necessary. Your readers are experimental
-scientists.
-"""
-
-instr = """Please generate high-level two paragraph documentation for the following experiment. The first paragraph
-should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""


@app.command()
-def predict(data_file: str, model_path: str) -> None:
+def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> None:
run = mlflow.active_run()

+sys_prompt = SYS[sys_id]
+instr_prompt = INSTR[instruc_id]
if run is None:
run = mlflow.start_run()
with run:
Expand All @@ -45,7 +38,7 @@ def predict(data_file: str, model_path: str) -> None:

pred = Predictor(model_path)
timer_start = timer()
-predictions = pred.predict(SYS, instr, inputs)
+predictions = pred.predict(sys_prompt, instr_prompt, inputs)
timer_end = timer()
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
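With this change, `predict` selects its prompts by ID: Typer parses the positional `sys_id`/`instruc_id` strings into enum members, and the `SYS`/`INSTR` dictionaries map those members to the full prompt text. A reduced, self-contained sketch (prompt wording abbreviated; the MLflow logging and model loading are omitted):

```python
from enum import Enum

import typer


class SystemPrompts(Enum):
    SYS_1 = "SYS_1"


class InstructionPrompts(Enum):
    INSTR_SWEETP_1 = "INSTR_SWEETP_1"


# Abbreviated stand-ins for the real prompt strings in autora.doc.runtime.prompts.
SYS = {SystemPrompts.SYS_1: "You are a technical documentation writer. ..."}
INSTR = {InstructionPrompts.INSTR_SWEETP_1: "Please generate high-level two paragraph documentation ..."}

app = typer.Typer()


@app.command()
def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> None:
    # data_file and model_path are unused in this sketch.
    sys_prompt = SYS[sys_id]
    instr_prompt = INSTR[instruc_id]
    typer.echo(f"system prompt: {sys_prompt}")
    typer.echo(f"instruction:   {instr_prompt}")


if __name__ == "__main__":
    # e.g. python sketch.py data.jsonl some-model SYS_1 INSTR_SWEETP_1
    app()
```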
55 changes: 34 additions & 21 deletions src/autora/doc/runtime/predict_hf.py
@@ -3,24 +3,22 @@

import torch
import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer

+from autora.doc.runtime.prompts import LLAMA2_INST_CLOSE, TEMP_LLAMA2

logger = logging.getLogger(__name__)


class Predictor:
def __init__(self, model_path: str):
-# Load the model in 4bit quantization for faster inference on smaller GPUs
-bnb_config = BitsAndBytesConfig(
-load_in_4bit=True,
-bnb_4bit_use_double_quant=True,
-bnb_4bit_quant_type="nf4",
-bnb_4bit_compute_dtype=torch.bfloat16,
-)
+config = self.get_config()

logger.info(f"Loading model from {model_path}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
-model_path, quantization_config=bnb_config, device_map="auto"
+model_path,
+**config,
)
logger.info("Model loaded")
self.pipeline = transformers.pipeline(
@@ -30,18 +28,8 @@ def __init__(self, model_path: str):
)

def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
-# Standard Llama2 template
-template = f"""
-[INST]<<SYS>>
-{sys}
-{instr}
-[INPUT]
-[/INST]
-"""
logger.info(f"Generating {len(inputs)} predictions")
-prompts = [template.replace("[INPUT]", input) for input in inputs]
+prompts = [TEMP_LLAMA2.format(sys=sys, instr=instr, input=input) for input in inputs]
# TODO: Make these parameters configurable
sequences = self.pipeline(
prompts,
@@ -54,10 +42,35 @@ def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
max_length=1000,
)

-results = [sequence[0]["generated_text"] for sequence in sequences]
+results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]
logger.info(f"Generated {len(results)} results")
return results

+@staticmethod
+def trim_prompt(output: str) -> str:
+marker = output.find(LLAMA2_INST_CLOSE)
+if marker == -1:
+logger.warning(f"Could not find end of prompt marker '{LLAMA2_INST_CLOSE}' in '{output}'")
+return output
+return output[marker + len(LLAMA2_INST_CLOSE) :]

def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
tokens: Dict[str, List[List[int]]] = self.tokenizer(input)
return tokens

+def get_config(self) -> Dict[str, str]:
+if torch.cuda.is_available():
+from transformers import BitsAndBytesConfig
+
+# Load the model in 4bit quantization for faster inference on smaller GPUs
+return {
+"quantization_config": BitsAndBytesConfig(
+load_in_4bit=True,
+bnb_4bit_use_double_quant=True,
+bnb_4bit_quant_type="nf4",
+bnb_4bit_compute_dtype=torch.bfloat16,
+),
+"device_map": "auto",
+}
+else:
+return {}
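Taken together, the refactor makes the quantization settings conditional: on a CUDA machine `get_config()` supplies the 4-bit BitsAndBytes arguments, while on CPU-only machines the empty dict lets `from_pretrained` fall back to a plain load. A hedged usage sketch (the model path and input are placeholders; actually running this requires the Llama-2 weights and, for the 4-bit path, a GPU with bitsandbytes installed):

```python
from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts

pred = Predictor("path/to/llama-2-7b-chat-hf")  # placeholder model directory
docs = pred.predict(
    SYS[SystemPrompts.SYS_1],
    INSTR[InstructionPrompts.INSTR_SWEETP_1],
    ["<python code for the experiment to document>"],  # placeholder input
)
print(docs[0])
```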
37 changes: 37 additions & 0 deletions src/autora/doc/runtime/prompts.py
@@ -0,0 +1,37 @@
from enum import Enum

LLAMA2_INST_CLOSE = "[/INST]\n"

# Standard Llama2 template
TEMP_LLAMA2 = """
[INST]<<SYS>>
{sys}
{instr}
{input}
[/INST]
"""


SYS_1 = """You are a technical documentation writer. You always write clear, concise, and accurate documentation for
scientific experiments. Your documentation focuses on the experiment's purpose, procedure, and results. Therefore,
details about specific python functions, packages, or libraries are not necessary. Your readers are experimental
scientists.
"""

INSTR_SWEETP_1 = """Please generate high-level two paragraph documentation for the following experiment. The first
paragraph should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""


class SystemPrompts(Enum):
SYS_1 = "SYS_1"


class InstructionPrompts(Enum):
SYS_1 = "SYS_1"
INSTR_SWEETP_1 = "INSTR_SWEETP_1"


SYS = {SystemPrompts.SYS_1: SYS_1}
INSTR = {InstructionPrompts.INSTR_SWEETP_1: INSTR_SWEETP_1}
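A small standalone illustration of how these pieces fit together (constants copied from the module above; the "generated" text is a stand-in): `TEMP_LLAMA2` wraps the system prompt, instruction, and input in Llama-2's `[INST]<<SYS>>` markup, and `Predictor.trim_prompt` later keeps only the text after the closing `[/INST]` marker.

```python
LLAMA2_INST_CLOSE = "[/INST]\n"

TEMP_LLAMA2 = """
[INST]<<SYS>>
{sys}
{instr}
{input}
[/INST]
"""

prompt = TEMP_LLAMA2.format(sys="<system prompt>", instr="<instruction>", input="<experiment code>")
# A causal LM echoes the prompt and then continues; trim everything up to the marker.
generated = prompt + "Generated documentation."
trimmed = generated[generated.find(LLAMA2_INST_CLOSE) + len(LLAMA2_INST_CLOSE):]
print(trimmed)  # -> Generated documentation.
```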
23 changes: 13 additions & 10 deletions tests/test.py
@@ -1,13 +1,16 @@
-from autora.doc import example_module
+from autora.doc.runtime.predict_hf import Predictor


-def test_greetings() -> None:
-"""Verify the output of the `greetings` function"""
-output = example_module.greetings()
-assert output == "Hello from LINCC-Frameworks!"
+def test_trim_prompt() -> None:
+"""Verify the output of the `trim_prompt` function"""
+no_marker = "Generated text with no marker"
+output = Predictor.trim_prompt(no_marker)
+assert output == no_marker


-def test_meaning() -> None:
-"""Verify the output of the `meaning` function"""
-output = example_module.meaning()
-assert output == 42
+with_marker = """
+The prompt is here
+[/INST]
+output
+"""
+output = Predictor.trim_prompt(with_marker)
+assert output == "output\n"
