Skip to content

Commit

Permalink
feat: Generate command
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Dec 6, 2023
1 parent 4f8d900 commit 424f8b6
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ az storage blob upload --account-name <account> --container <container> --file

Prediction
```sh
az ml job create -f azureml/predict.yml --set display_name="Test prediction job" --web
az ml job create -f azureml/eval.yml --set display_name="Test prediction job" --web
```

Notes:
Expand Down
2 changes: 1 addition & 1 deletion azureml/predict.yml → azureml/eval.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
python -m autora.doc.pipelines.main predict
python -m autora.doc.pipelines.main eval
${{inputs.data_dir}}/data.jsonl
${{inputs.model_dir}}/llama-2-7b-chat-hf
SYS_1
Expand Down
18 changes: 18 additions & 0 deletions azureml/generate.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
command: >
python -m autora.doc.pipelines.main generate
--model-path ${{inputs.model_dir}}/llama-2-7b-chat-hf
--output ./outputs/output.txt
autora/doc/pipelines/main.py
code: ../src
inputs:
model_dir:
type: uri_folder
path: azureml://datastores/workspaceblobstore/paths/base_models
environment:
image: mcr.microsoft.com/azureml/curated/acpt-pytorch-2.0-cuda11.7:21
conda_file: conda.yml
display_name: autodoc_prediction
compute: azureml:v100cluster
experiment_name: autodoc_prediction
description: |
25 changes: 21 additions & 4 deletions src/autora/doc/pipelines/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@


@app.command()
def predict(
data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts
) -> List[str]:
def eval(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> List[str]:
run = mlflow.active_run()

sys_prompt = SYS[sys_id]
Expand All @@ -33,7 +31,6 @@ def predict(
logger.info(f"running predict with {data_file}")
logger.info(f"model path: {model_path}")

# predictions = []
with jsonlines.open(data_file) as reader:
items = [item for item in reader]
inputs = [item["instruction"] for item in items]
Expand All @@ -57,6 +54,26 @@ def predict(
return predictions


@app.command()
def generate(
    python_file: str,
    model_path: str = "meta-llama/llama-2-7b-chat-hf",
    output: str = "output.txt",
    sys_id: SystemPrompts = SystemPrompts.SYS_1,
    instruc_id: InstructionPrompts = InstructionPrompts.INSTR_SWEETP_1,
) -> None:
    """Generate text for a single source file and write it to ``output``.

    The whole content of ``python_file`` is passed to the predictor as one
    input, together with the system and instruction prompts selected by
    ``sys_id`` and ``instruc_id``. The single prediction is written to
    ``output``, replacing any existing file.

    Args:
        python_file: Path of the file whose content is used as model input.
        model_path: Model identifier or path handed to ``Predictor``.
        output: Path of the text file that receives the prediction.
        sys_id: Key of the system prompt to look up in ``SYS``.
        instruc_id: Key of the instruction prompt to look up in ``INSTR``.

    Raises:
        AssertionError: If the predictor does not return exactly one output.
    """
    # Python source is UTF-8 by default; do not depend on the locale encoding.
    with open(python_file, "r", encoding="utf-8") as f:
        inputs = [f.read()]
    sys_prompt = SYS[sys_id]
    instr_prompt = INSTR[instruc_id]
    pred = Predictor(model_path)
    predictions = pred.predict(sys_prompt, instr_prompt, inputs)
    # One input was submitted, so exactly one prediction is expected back.
    assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
    logger.info(f"Writing output to {output}")
    # Generated text may contain non-ASCII characters; write it as UTF-8 so the
    # result does not depend on (or fail under) the platform's default encoding.
    with open(output, "w", encoding="utf-8") as f:
        f.write(predictions[0])


@app.command()
def import_model(model_name: str) -> None:
pass
Expand Down
2 changes: 1 addition & 1 deletion src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def predict(self, sys: str, instr: str, inputs: List[str]) -> List[str]:
top_k=40,
num_return_sequences=1,
eos_token_id=self.tokenizer.eos_token_id,
max_length=1000,
max_length=2048,
)

results = [Predictor.trim_prompt(sequence[0]["generated_text"]) for sequence in sequences]
Expand Down
4 changes: 2 additions & 2 deletions src/autora/doc/runtime/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
paragraph should explain the purpose and the second one the procedure, but don't use the word 'Paragraph'"""


class SystemPrompts(Enum):
class SystemPrompts(str, Enum):
SYS_1 = "SYS_1"


class InstructionPrompts(Enum):
class InstructionPrompts(str, Enum):
INSTR_SWEETP_1 = "INSTR_SWEETP_1"


Expand Down
14 changes: 12 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from autora.doc.pipelines.main import predict
from autora.doc.pipelines.main import eval, generate
from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts

# dummy HF model for testing
Expand All @@ -9,7 +9,17 @@

def test_predict() -> None:
data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
outputs = predict(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
outputs = eval(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
assert len(outputs) == 3, "Expected 3 outputs"
for output in outputs:
assert len(output) > 0, "Expected non-empty output"


def test_generate() -> None:
    """Run ``generate`` on this test file and check that it writes non-empty output."""
    source_path = __file__
    output = Path("output.txt")
    # Remove any leftover file so a stale result cannot satisfy the checks below.
    output.unlink(missing_ok=True)

    generate(
        source_path,
        TEST_HF_MODEL,
        str(output),
        SystemPrompts.SYS_1,
        InstructionPrompts.INSTR_SWEETP_1,
    )

    assert output.exists(), f"Expected output file {output} to exist"
    written = output.read_text()
    assert len(written) > 0, f"Expected non-empty output file {output}"

0 comments on commit 424f8b6

Please sign in to comment.