Refine inference code and add alternative inference with vllm
binkjakub committed Dec 12, 2024
1 parent 0387b7a commit 3059ce3
Showing 8 changed files with 203 additions and 8 deletions.
6 changes: 3 additions & 3 deletions configs/predict.yaml
@@ -1,15 +1,15 @@
defaults:
- model: ???
- dataset: pl-court-instruct
- dataset: ???
- _self_
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

model:
batch_size: 4
batch_size: 1

device_map: 'auto'
output_file: ???
output_file: data/experiments/predict/${hydra:runtime.choices.dataset}/${hydra:runtime.choices.model}/outputs_${random_seed}.json
truncate_context: True
generate_kwargs:
max_new_tokens: ${dataset.max_output_tokens}
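With this change the output path is derived from the Hydra runtime choices instead of being supplied manually. For example (the override values are hypothetical), running prediction with dataset=pl-court-instruct model=llama-3 random_seed=42 resolves output_file to data/experiments/predict/pl-court-instruct/llama-3/outputs_42.json.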
11 changes: 10 additions & 1 deletion juddges/models/factory.py
@@ -2,6 +2,7 @@
from typing import Any

import torch
from loguru import logger
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

@@ -84,11 +85,19 @@ def _get_model_tokenizer(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
if torch.cuda.is_available():
kwargs["attn_implementation"] = "flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(llm_config.name, **kwargs)
model = AutoModelForCausalLM.from_pretrained(
llm_config.name,
torch_dtype="auto",
**kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(llm_config.name)

if llm_config.adapter_path is not None:
logger.info(f"Loading adapter from {llm_config.adapter_path}")
model = PeftModel.from_pretrained(model, llm_config.adapter_path)
model = model.merge_and_unload(safe_merge=True)

return model, tokenizer
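For reference, a minimal standalone sketch of the loading-and-merging pattern applied in this factory (the base model name and adapter path are hypothetical placeholders). Merging the LoRA weights into the base model removes the adapter indirection at inference time, and safe_merge additionally checks the merged weights for NaNs:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_name = "mistralai/Mistral-7B-Instruct-v0.2"  # hypothetical base model
adapter_path = "data/experiments/fine-tune/adapter"  # hypothetical adapter directory

kwargs = {}
if torch.cuda.is_available():
    # FlashAttention-2 kernels are only available on CUDA devices, hence the guard
    kwargs["attn_implementation"] = "flash_attention_2"

# torch_dtype="auto" keeps the dtype stored in the checkpoint (e.g. bfloat16)
model = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype="auto", **kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_name)

# fold the LoRA adapter into the base weights so inference runs at plain-model speed
model = PeftModel.from_pretrained(model, adapter_path)
model = model.merge_and_unload(safe_merge=True)
model.eval()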
2 changes: 1 addition & 1 deletion juddges/preprocessing/text_encoder.py
@@ -46,7 +46,7 @@ def __call__(self, batch: dict[str, list[Any]]) -> dict[str, Tensor]:
return tokenized


class TextEncoderForOpenAIEval:
class TextEncoderForEvalPlainTextFormat:
def __init__(self, truncator: ContextTruncatorTiktoken):
self.truncator = truncator

1 change: 1 addition & 0 deletions requirements.txt
@@ -37,6 +37,7 @@ trl==0.12.2
typer==0.9.0
wandb==0.19.0
weaviate-client==4.8.1
vllm==0.6.4.post1
xmltodict==0.13.0
xlsxwriter==3.2.0

13 changes: 12 additions & 1 deletion scripts/sft/predict.py
@@ -21,6 +21,7 @@
from juddges.utils.config import resolve_config
from juddges.utils.misc import sort_dataset_by_input_length

torch.set_float32_matmul_precision("high")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_PROC = int(os.getenv("NUM_PROC", 1))

@@ -45,6 +46,16 @@ def corrected_max_seq_length(self) -> int:
@hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="predict.yaml")
@torch.inference_mode()
def main(cfg: DictConfig) -> None:
"""Runs inference on the given dataset using the given model.
The outputs are saved to a file in the following JSON format:
[
{
"answer": str,
"gold": str
},
...
]
"""
config = PredictConfig(**resolve_config(cfg))
logger.info(f"config:\n{pformat(config.model_dump())}")

@@ -65,8 +76,8 @@ def main(cfg: DictConfig) -> None:

FastLanguageModel.for_inference(model_pack.model)
else:
model_pack.model.compile()
model_pack.model.eval()
model_pack.model.compile()

if config.model.batch_size > 1 and config.model.padding is False:
raise ValueError("Padding has to be enabled if batch size > 1.")
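For downstream evaluation, the predictions file described in the docstring above can be read back with plain json; a short sketch (the path is hypothetical and follows the default output_file pattern from configs/predict.yaml):

import json

with open("data/experiments/predict/pl-court-instruct/llama-3/outputs_42.json") as f:
    predictions = json.load(f)

# each record pairs the model answer with the gold reference
print(predictions[0]["answer"])
print(predictions[0]["gold"])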
86 changes: 86 additions & 0 deletions scripts/sft/predict_vllm.py
@@ -0,0 +1,86 @@
import json
import os
from pathlib import Path
from pprint import pformat
from typing import Any

import hydra
import torch
from datasets import load_dataset
from loguru import logger
from omegaconf import DictConfig
from pydantic import BaseModel, Field
from vllm import LLM, SamplingParams

from juddges.config import DatasetConfig, LLMConfig
from juddges.preprocessing.context_truncator import ContextTruncator
from juddges.preprocessing.text_encoder import TextEncoderForEvalPlainTextFormat
from juddges.settings import CONFIG_PATH
from juddges.utils.config import resolve_config

torch.set_float32_matmul_precision("high")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_PROC = int(os.getenv("NUM_PROC", 1))


class PredictConfig(BaseModel, extra="forbid"):
model: LLMConfig
dataset: DatasetConfig
device_map: str
output_file: Path
truncate_context: bool
generate_kwargs: dict[str, Any] = Field(default_factory=dict)
random_seed: int

@property
def corrected_max_seq_length(self) -> int:
return self.model.max_seq_length - self.dataset.max_output_tokens


@hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="predict.yaml")
@torch.inference_mode()
def main(cfg: DictConfig) -> None:
config = PredictConfig(**resolve_config(cfg))
logger.info(f"config:\n{pformat(config.model_dump())}")

output_file = Path(config.output_file)
output_file.parent.mkdir(parents=True, exist_ok=True)

ds = load_dataset(config.dataset.name, split="test")

llm = LLM(
model=config.model.name,
quantization="bitsandbytes",
load_format="bitsandbytes",
enable_lora=True,
qlora_adapter_name_or_path=config.model.adapter_path,
max_model_len=config.model.max_seq_length,
max_num_seqs=config.model.batch_size,
)

truncator = ContextTruncator(
tokenizer=llm.llm_engine.tokenizer.get_lora_tokenizer(),
max_length=config.corrected_max_seq_length,
)
encoder = TextEncoderForEvalPlainTextFormat(truncator=truncator)
ds = ds.map(encoder, num_proc=NUM_PROC)
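# NOTE: the select below limits inference to the first 10 examples (appears to be a smoke-test limit)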
ds = ds.select(range(10))

params = SamplingParams(
max_tokens=config.generate_kwargs.get("max_new_tokens", 100),
temperature=config.generate_kwargs.get("temperature", 0.0),
)

outputs = llm.generate(
prompts=ds["final_input"],
sampling_params=params,
)
# llm.generate returns RequestOutput objects; keep only the generated text
results = [{"answer": out.outputs[0].text, "gold": gold} for out, gold in zip(outputs, ds["output"])]

with open(output_file, "w") as f:
json.dump(results, f, indent="\t", ensure_ascii=False)


if __name__ == "__main__":
main()
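For reference, a minimal sketch of the offline vLLM generation pattern the script builds on (the model name and prompt are hypothetical). generate returns RequestOutput objects, and the generated text sits under outputs[0].text:

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2")  # hypothetical model
params = SamplingParams(temperature=0.0, max_tokens=100)

request_outputs = llm.generate(
    prompts=["Summarise the ruling in one sentence."],
    sampling_params=params,
)
for request_output in request_outputs:
    # one CompletionOutput per sampled sequence; take the first
    print(request_output.outputs[0].text)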
4 changes: 2 additions & 2 deletions scripts/sft/predict_with_api.py
@@ -19,7 +19,7 @@

from juddges.config import DatasetConfig
from juddges.preprocessing.context_truncator import ContextTruncatorTiktoken
from juddges.preprocessing.text_encoder import TextEncoderForOpenAIEval
from juddges.preprocessing.text_encoder import TextEncoderForEvalPlainTextFormat
from juddges.settings import CONFIG_PATH
from juddges.utils.config import resolve_config

@@ -55,7 +55,7 @@ def main(cfg: DictConfig) -> None:
gold_outputs = [item["output"] for item in ds]

truncator = ContextTruncatorTiktoken(model=config.model_version, max_length=config.max_seq_len)
encoder = TextEncoderForOpenAIEval(truncator=truncator)
encoder = TextEncoderForEvalPlainTextFormat(truncator=truncator)
ds = ds.map(encoder, num_proc=NUM_PROC)

predictor = OpenAIPredictor(config=config)
88 changes: 88 additions & 0 deletions scripts/utils/convert_deepspeed.py
@@ -0,0 +1,88 @@
import os
import shutil
import subprocess
from pathlib import Path
from pprint import pformat

import typer
from loguru import logger
from tqdm import tqdm

SCRIPT_NAME = "zero_to_fp32.py"
CONVERTED_MODEL_PATTERN = "model*.safetensors"


def main(
root_dir: Path = typer.Option(),
adapter_only: bool = typer.Option(False, help="Only convert adapter"),
remove: bool = typer.Option(False, help="Removes original deepspeed checkpoints"),
remove_for_converted: bool = typer.Option(
False, help="Removes only original deepspeed checkpoints for already converted models"
),
) -> None:
checkpoint_dirs = [script_file.parent for script_file in root_dir.rglob(SCRIPT_NAME)]
logger.info(f"Found {len(checkpoint_dirs)} checkpoints to convert:\n{pformat(checkpoint_dirs)}")
for ckpt_dir in tqdm(checkpoint_dirs, desc="Converting checkpoints"):
logger.info(f"Converting {ckpt_dir}")
if list(ckpt_dir.glob(CONVERTED_MODEL_PATTERN)):
logger.warning(f"Model already converted, skipping {ckpt_dir}")
if remove_for_converted:
logger.info(f"Removing deepspeed artifacts for {ckpt_dir}")
remove_deepspeed_artifacts(ckpt_dir)
continue
else:
convert(ckpt_dir)

# deepspeed saves the model as model.safetensors; it needs to be renamed to adapter_model.safetensors
if adapter_only:
# there should be an (almost) empty adapter_model.safetensors
assert (ckpt_dir / "adapter_model.safetensors").exists()
for model_file in ckpt_dir.glob("model*.safetensors"):
model_file.rename(
model_file.with_stem(model_file.stem.replace("model", "adapter_model"))
)

if remove:
remove_deepspeed_artifacts(ckpt_dir)


def convert(ckpt_dir: Path) -> None:
script_file = ckpt_dir / SCRIPT_NAME
step_dir = get_latest_step_dir(ckpt_dir)
logger.info(f"Converting {step_dir}")
cmd = [
"python",
str(script_file),
str(ckpt_dir), # checkpoint_dir
str(ckpt_dir), # output_dir
"--safe_serialization", # writes as safetensors file
"--max_shard_size",
"5GB",
"--tag",
step_dir.name,  # points to the global_step<step_num> directory
]
env = os.environ.copy() | {"CUDA_VISIBLE_DEVICES": "-1"}
subprocess.run(cmd, check=True, env=env)


def remove_deepspeed_artifacts(ckpt_dir: Path) -> None:
step_dir = get_latest_step_dir(ckpt_dir)
logger.info(f"Removing {step_dir}")
shutil.rmtree(step_dir)

for rng_file in ckpt_dir.glob("rng_state_*.pth"):
os.remove(rng_file)

os.remove(ckpt_dir / SCRIPT_NAME)
os.remove(ckpt_dir / "latest")
os.remove(ckpt_dir / "scheduler.pt")


def get_latest_step_dir(ckpt_dir: Path) -> Path:
with open(ckpt_dir / "latest") as f:
step_dirname = f.read().strip()
return ckpt_dir / step_dirname


if __name__ == "__main__":
typer.run(main)
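Usage note: the converter is run against an experiment root, e.g. python scripts/utils/convert_deepspeed.py --root-dir data/experiments/fine-tune --remove (the directory is a hypothetical example). It locates every checkpoint directory containing zero_to_fp32.py, runs the conversion on CPU (CUDA_VISIBLE_DEVICES=-1), writes safetensors shards of at most 5GB, and optionally deletes the original DeepSpeed state.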
