diff --git a/configs/predict.yaml b/configs/predict.yaml
index 43cf311..04ae08b 100644
--- a/configs/predict.yaml
+++ b/configs/predict.yaml
@@ -1,15 +1,15 @@
 defaults:
   - model: ???
-  - dataset: pl-court-instruct
+  - dataset: ???
   - _self_
   - override hydra/hydra_logging: disabled
   - override hydra/job_logging: disabled
 
 model:
-  batch_size: 4
+  batch_size: 1
   device_map: 'auto'
 
-output_file: ???
+output_file: data/experiments/predict/${hydra:runtime.choices.dataset}/${hydra:runtime.choices.model}/outputs_${random_seed}.json
 truncate_context: True
 generate_kwargs:
   max_new_tokens: ${dataset.max_output_tokens}
diff --git a/juddges/models/factory.py b/juddges/models/factory.py
index 901ee13..5395e82 100644
--- a/juddges/models/factory.py
+++ b/juddges/models/factory.py
@@ -2,6 +2,7 @@
 from typing import Any
 
 import torch
+from loguru import logger
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
@@ -84,11 +85,19 @@ def _get_model_tokenizer(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
+        if torch.cuda.is_available():
+            kwargs["attn_implementation"] = "flash_attention_2"
 
-    model = AutoModelForCausalLM.from_pretrained(llm_config.name, **kwargs)
+    model = AutoModelForCausalLM.from_pretrained(
+        llm_config.name,
+        torch_dtype="auto",
+        **kwargs,
+    )
     tokenizer = AutoTokenizer.from_pretrained(llm_config.name)
 
     if llm_config.adapter_path is not None:
+        logger.info(f"Loading adapter from {llm_config.adapter_path}")
         model = PeftModel.from_pretrained(model, llm_config.adapter_path)
+        model = model.merge_and_unload(safe_merge=True)
 
     return model, tokenizer
diff --git a/juddges/preprocessing/text_encoder.py b/juddges/preprocessing/text_encoder.py
index 42eb6c3..875f13c 100644
--- a/juddges/preprocessing/text_encoder.py
+++ b/juddges/preprocessing/text_encoder.py
@@ -46,7 +46,7 @@ def __call__(self, batch: dict[str, list[Any]]) -> dict[str, Tensor]:
         return tokenized
 
 
-class TextEncoderForOpenAIEval:
+class TextEncoderForEvalPlainTextFormat:
     def __init__(self, truncator: ContextTruncatorTiktoken):
         self.truncator = truncator
 
diff --git a/requirements.txt b/requirements.txt
index 59a74e9..18df9d0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,6 +37,7 @@ trl==0.12.2
 typer==0.9.0
 wandb==0.19.0
 weaviate-client==4.8.1
+vllm==0.6.4.post1
 xmltodict==0.13.0
 xlsxwriter==3.2.0
 
diff --git a/scripts/sft/predict.py b/scripts/sft/predict.py
index cab27de..b611631 100644
--- a/scripts/sft/predict.py
+++ b/scripts/sft/predict.py
@@ -21,6 +21,7 @@
 from juddges.utils.config import resolve_config
 from juddges.utils.misc import sort_dataset_by_input_length
 
+torch.set_float32_matmul_precision("high")
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 NUM_PROC = int(os.getenv("NUM_PROC", 1))
 
@@ -45,6 +46,16 @@ def corrected_max_seq_length(self) -> int:
 @hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="predict.yaml")
 @torch.inference_mode()
 def main(cfg: DictConfig) -> None:
+    """Performs inference on a given dataset using a given model.
+    The outputs are saved to a file in the following JSON format:
+    [
+        {
+            "answer": str,
+            "gold": str
+        },
+        ...
+    ]
+    """
     config = PredictConfig(**resolve_config(cfg))
     logger.info(f"config:\n{pformat(config.model_dump())}")
 
@@ -65,8 +76,8 @@ def main(cfg: DictConfig) -> None:
 
         FastLanguageModel.for_inference(model_pack.model)
     else:
-        model_pack.model.compile()
         model_pack.model.eval()
+        model_pack.model.compile()
 
     if config.model.batch_size > 1 and config.model.padding is False:
         raise ValueError("Padding has to be enabled if batch size > 1.")
diff --git a/scripts/sft/predict_vllm.py b/scripts/sft/predict_vllm.py
new file mode 100644
index 0000000..254b156
--- /dev/null
+++ b/scripts/sft/predict_vllm.py
@@ -0,0 +1,86 @@
+import json
+import os
+from pathlib import Path
+from pprint import pformat
+from typing import Any
+
+import hydra
+import torch
+from datasets import load_dataset
+from loguru import logger
+from omegaconf import DictConfig
+from pydantic import BaseModel
+from pydantic import Field
+from vllm import LLM, SamplingParams
+
+from juddges.config import DatasetConfig, LLMConfig
+from juddges.preprocessing.context_truncator import ContextTruncator
+from juddges.preprocessing.text_encoder import TextEncoderForEvalPlainTextFormat
+from juddges.settings import CONFIG_PATH
+from juddges.utils.config import resolve_config
+
+torch.set_float32_matmul_precision("high")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+NUM_PROC = int(os.getenv("NUM_PROC", 1))
+
+
+class PredictConfig(BaseModel, extra="forbid"):
+    model: LLMConfig
+    dataset: DatasetConfig
+    device_map: str
+    output_file: Path
+    truncate_context: bool
+    generate_kwargs: dict[str, Any] = Field(default_factory=dict)
+    random_seed: int
+
+    @property
+    def corrected_max_seq_length(self) -> int:
+        return self.model.max_seq_length - self.dataset.max_output_tokens
+
+
+@hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="predict.yaml")
+@torch.inference_mode()
+def main(cfg: DictConfig) -> None:
+    config = PredictConfig(**resolve_config(cfg))
+    logger.info(f"config:\n{pformat(config.model_dump())}")
+
+    output_file = Path(config.output_file)
+    output_file.parent.mkdir(parents=True, exist_ok=True)
+
+    ds = load_dataset(config.dataset.name, split="test")
+
+    llm = LLM(
+        model=config.model.name,
+        quantization="bitsandbytes",
+        load_format="bitsandbytes",
+        enable_lora=True,
+        qlora_adapter_name_or_path=config.model.adapter_path,
+        max_model_len=config.model.max_seq_length,
+        max_num_seqs=config.model.batch_size,
+    )
+
+    truncator = ContextTruncator(
+        tokenizer=llm.llm_engine.tokenizer.get_lora_tokenizer(),
+        max_length=config.corrected_max_seq_length,
+    )
+    encoder = TextEncoderForEvalPlainTextFormat(truncator=truncator)
+    ds = ds.map(encoder, num_proc=NUM_PROC)
+    ds = ds.select(range(10))
+
+    params = SamplingParams(
+        max_tokens=config.generate_kwargs.get("max_new_tokens", 100),
+        temperature=config.generate_kwargs.get("temperature", 0.0),
+    )
+
+    outputs = llm.generate(
+        prompts=ds["final_input"],
+        sampling_params=params,
+    )
+    results = [{"answer": out.outputs[0].text, "gold": gold} for out, gold in zip(outputs, ds["output"])]
+
+    with open(output_file, "w") as f:
+        json.dump(results, f, indent="\t", ensure_ascii=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/sft/predict_with_api.py b/scripts/sft/predict_with_api.py
index 508d983..b157c85 100644
--- a/scripts/sft/predict_with_api.py
+++ b/scripts/sft/predict_with_api.py
@@ -19,7 +19,7 @@
 
 from juddges.config import DatasetConfig
 from juddges.preprocessing.context_truncator import ContextTruncatorTiktoken
-from juddges.preprocessing.text_encoder import TextEncoderForOpenAIEval
+from juddges.preprocessing.text_encoder import TextEncoderForEvalPlainTextFormat
 from juddges.settings import CONFIG_PATH
 from juddges.utils.config import resolve_config
 
@@ -55,7 +55,7 @@ def main(cfg: DictConfig) -> None:
     gold_outputs = [item["output"] for item in ds]
 
     truncator = ContextTruncatorTiktoken(model=config.model_version, max_length=config.max_seq_len)
-    encoder = TextEncoderForOpenAIEval(truncator=truncator)
+    encoder = TextEncoderForEvalPlainTextFormat(truncator=truncator)
     ds = ds.map(encoder, num_proc=NUM_PROC)
 
     predictor = OpenAIPredictor(config=config)
diff --git a/scripts/utils/convert_deepspeed.py b/scripts/utils/convert_deepspeed.py
new file mode 100644
index 0000000..c42c656
--- /dev/null
+++ b/scripts/utils/convert_deepspeed.py
@@ -0,0 +1,88 @@
+import os
+import shutil
+import subprocess
+from pathlib import Path
+from pprint import pformat
+
+import typer
+from loguru import logger
+from tqdm import tqdm
+
+SCRIPT_NAME = "zero_to_fp32.py"
+CONVERTED_MODEL_PATTERN = "model*.safetensors"
+
+
+def main(
+    root_dir: Path = typer.Option(),
+    adapter_only: bool = typer.Option(False, help="Only convert adapter"),
+    remove: bool = typer.Option(False, help="Removes original deepspeed checkpoints"),
+    remove_for_converted: bool = typer.Option(
+        False, help="Removes only original deepspeed checkpoints for already converted models"
+    ),
+) -> None:
+    checkpoint_dirs = [script_file.parent for script_file in root_dir.rglob(SCRIPT_NAME)]
+    logger.info(f"Found {len(checkpoint_dirs)} checkpoints to convert:\n{pformat(checkpoint_dirs)}")
+    for ckpt_dir in tqdm(checkpoint_dirs, desc="Converting checkpoints"):
+        logger.info(f"Converting {ckpt_dir}")
+        if list(ckpt_dir.glob(CONVERTED_MODEL_PATTERN)):
+            logger.warning(f"Model already converted, skipping {ckpt_dir}")
+            if remove_for_converted:
+                logger.info(f"Removing deepspeed artifacts for {ckpt_dir}")
+                remove_deepspeed_artifacts(ckpt_dir)
+            continue
+        else:
+            convert(ckpt_dir)
+
+            # deepspeed saves model as model.safetensors, need to rename it to adapter_model.safetensors
+            if adapter_only:
+                # there should be (almost) empty adapter_model.safetensors
+                assert (ckpt_dir / "adapter_model.safetensors").exists()
+                for model_file in ckpt_dir.glob("model*.safetensors"):
+                    model_file.rename(
+                        model_file.with_stem(model_file.stem.replace("model", "adapter_model"))
+                    )
+
+            if remove:
+                remove_deepspeed_artifacts(ckpt_dir)
+
+
+def convert(ckpt_dir: Path) -> None:
+    script_file = ckpt_dir / SCRIPT_NAME
+    step_dir = get_latest_step_dir(ckpt_dir)
+    logger.info(f"Converting {step_dir}")
+    cmd = [
+        "python",
+        str(script_file),
+        str(ckpt_dir),  # checkpoint_dir
+        str(ckpt_dir),  # output_dir
+        "--safe_serialization",  # writes as safetensors file
+        "--max_shard_size",
+        "5GB",
+        "--tag",
+        step_dir.name,  # points to directory globalstep
+    ]
+    env = os.environ.copy() | {"CUDA_VISIBLE_DEVICES": "-1"}
+    subprocess.run(cmd, check=True, env=env)
+
+
+def remove_deepspeed_artifacts(ckpt_dir: Path) -> None:
+    step_dir = get_latest_step_dir(ckpt_dir)
+    logger.info(f"Removing {step_dir}")
+    shutil.rmtree(step_dir)
+
+    for rng_file in ckpt_dir.glob("rng_state_*.pth"):
+        os.remove(rng_file)
+
+    os.remove(ckpt_dir / SCRIPT_NAME)
+    os.remove(ckpt_dir / "latest")
+    os.remove(ckpt_dir / "scheduler.pt")
+
+
+def get_latest_step_dir(ckpt_dir: Path) -> Path:
+    with open(ckpt_dir / "latest") as f:
+        step_dirname = f.read().strip()
+    return ckpt_dir / step_dirname
+
+
+if __name__ == "__main__":
+    typer.run(main)