CTranslate2 Benchmark with Mistral Support (#170)
* Added support for BaseClass and Mistral with memory profiling
* Removed Docker support with the latest CTranslate2 release
* Added the latest CTranslate2 version
* Removed runs with Docker and added Mistral model support
* Removed Docker support and added Mistral support
* Added performance logs for Mistral and Llama
* Engine-specific README with a qualitative comparison
1 parent ebd217b · commit bc8929d
Showing 8 changed files with 274 additions and 274 deletions.
@@ -1,158 +1,115 @@

Removed: the previous standalone benchmark script, which loaded a Llama 2 CTranslate2 model with a SentencePiece tokenizer, streamed generated tokens to measure tokens per second for each supported compute type, and wrote its own report:

```python
import argparse
import logging
import os
import sys
import time
from collections import defaultdict

import ctranslate2
import numpy as np
import sentencepiece as spm

logging.getLogger("ctranslate2").setLevel(logging.ERROR)
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_compute_types(device):
    compute_types = set()
    if device in ("cuda", "cpu"):
        return set(ctranslate2.get_supported_compute_types(device))
    else:
        return compute_types


class CTranslateBenchmark:
    def __init__(self, model_path, device, compute_type):
        self.model_path = model_path
        self.results = []
        self.device = device
        self.compute_type = compute_type

    def load_model(self):
        self.generator = ctranslate2.Generator(
            self.model_path,
            device=self.device,
            compute_type=self.compute_type,
        )
        self.sp = spm.SentencePieceProcessor(
            os.path.join(self.model_path, "tokenizer.model")
        )
        return self

    def run_model(self, prompt, max_tokens):
        prompt_tokens = ["<s>"] + self.sp.encode_as_pieces(
            f"{B_INST} {prompt.strip()} {E_INST}"
        )
        start = time.time()
        step_results = self.generator.generate_tokens(
            prompt_tokens,
            max_length=max_tokens,
            sampling_temperature=0.6,
            sampling_topk=20,
            sampling_topp=1,
        )
        count = 0
        for _ in self.generate_words(step_results):
            count += 1
        return count / (time.time() - start)

    def benchmark(self, prompt, max_tokens, repetitions):
        for i in range(repetitions):
            logging.info(
                f"Running repetition [{str(i+1).zfill(len(str(repetitions)))}/{repetitions}]"
            )
            tokens_per_second = self.run_model(prompt, max_tokens)
            self.results.append(tokens_per_second)

    def generate_words(self, step_results):
        tokens_buffer = []

        for step_result in step_results:
            is_new_word = step_result.token.startswith("▁")

            if is_new_word and tokens_buffer:
                word = self.sp.decode(tokens_buffer)
                if word:
                    yield word
                tokens_buffer = []

            tokens_buffer.append(step_result.token_id)

        if tokens_buffer:
            word = self.sp.decode(tokens_buffer)
            if word:
                yield word


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CTranslate Benchmark Llama model.")
    parser.add_argument(
        "--prompt",
        type=str,
        help="The prompt for the model.",
    )
    parser.add_argument("--max_tokens", type=int, help="The maximum number of tokens.")
    parser.add_argument(
        "--repetitions",
        type=int,
        help="The number of repetitions for the benchmark.",
    )
    parser.add_argument(
        "--device",
        help="Device to use for the benchmark.",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        help="Path to the log file for writing logs (in append mode).",
    )
    parser.add_argument(
        "--models_dir",
        type=str,
        help="Path to the models directory.",
    )
    args = parser.parse_args()
    if args.device == "metal":
        logging.info(f"Skipping benchmark with device={args.device}")
        sys.exit(0)

    logging.info(
        f"Running benchmark with: max_tokens={args.max_tokens} prompt={args.prompt} "
        + f"repetitions={args.repetitions} device={args.device}"
    )
    report = defaultdict(lambda: defaultdict(float))
    compute_types = get_compute_types(args.device)

    for compute_type in compute_types.intersection({"float32", "float16", "int8"}):
        logging.info(f"Running ctranslate benchmark with {compute_type}")
        ctranslate_bench = CTranslateBenchmark(
            f"{args.models_dir}/llama-2-7b-ctranslate2-{compute_type}",
            device=args.device,
            compute_type=compute_type,
        ).load_model()
        ctranslate_bench.benchmark(
            max_tokens=args.max_tokens, prompt=args.prompt, repetitions=args.repetitions
        )
        report["ctranslate"][compute_type] = {
            "mean": np.mean(ctranslate_bench.results),
            "std": np.std(ctranslate_bench.results),
        }

    logging.info("Benchmark report")
    with open(args.log_file, "a") as file:
        for framework, quantizations in report.items():
            for quantization, stats in quantizations.items():
                logging.info(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}"
                )
                print(
                    f"{framework}, {quantization}: {stats['mean']:.2f} ± {stats['std']:.2f}",
                    file=file,
                )
```
Added: the new implementation, built on the shared BaseBenchmarkClass, with a HuggingFace AutoTokenizer, Llama/Mistral model paths assembled into a runner dictionary, and reporting delegated to make_report:

```python
import os
import sys

import ctranslate2
from transformers import AutoTokenizer

# have to hard code this thing
sys.path.append(os.getcwd())

from common.base import BaseBenchmarkClass  # noqa
from common.utils import launch_cli, make_report  # noqa


class CTranslateBenchmark(BaseBenchmarkClass):
    def __init__(
        self,
        model_path: str,
        model_name: str,
        benchmark_name: str,
        precision: str,
        device: str,
        experiment_name: str,
    ) -> None:
        assert precision in ["float32", "float16", "int8"], ValueError(
            "Precision other than: 'float32', 'float16', 'int8' are not supported"
        )
        super().__init__(
            model_path=model_path,
            model_name=model_name,
            benchmark_name=benchmark_name,
            precision=precision,
            device=device,
            experiment_name=experiment_name,
        )

    def load_model_and_tokenizer(self):
        self.model = ctranslate2.Generator(self.model_path, device=self.device)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        return self

    def preprocess(self, prompt: str, chat_mode: bool = True, for_benchmarks=True):
        if chat_mode:
            template = self.get_chat_template_with_instruction(
                prompt=prompt, for_benchmarks=for_benchmarks
            )
            prompt = self.tokenizer.apply_chat_template(template, tokenize=False)

        tokenized_input = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(prompt)
        )
        return {
            "prompt": prompt,
            "input_tokens": tokenized_input,
            "tensor": None,
            "num_input_tokens": len(tokenized_input),
        }

    def run_model(
        self, inputs: dict, max_tokens: int, temperature: float = 0.1
    ) -> dict:
        tokenized_input = inputs["input_tokens"]
        num_input_tokens = inputs["num_input_tokens"] - 1

        output = self.model.generate_batch(
            [tokenized_input], max_length=max_tokens, sampling_temperature=0.1
        )

        output_tokens = output[0].sequences_ids[0][num_input_tokens:]
        output_prompt = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
        return {
            "output_prompt": output_prompt,
            "output_tokens": output_tokens,
            "num_output_tokens": len(output_tokens),
        }

    def postprocess(self, output: dict) -> str:
        return output["output_prompt"]


if __name__ == "__main__":
    parser = launch_cli(description="CTransformers Benchmark.")
    args = parser.parse_args()

    model_folder = os.path.join(os.getcwd(), "models")
    model_name = (
        f"{args.model_name}-2-7b-chat-ctranslate2-"
        if args.model_name == "llama"
        else f"{args.model_name}-7b-v0.1-instruct-ctranslate2-"
    )

    runner_dict = {
        "cuda": [
            {
                "precision": "float32",
                "model_path": os.path.join(model_folder, model_name + "float32"),
            },
            {
                "precision": "float16",
                "model_path": os.path.join(model_folder, model_name + "float16"),
            },
            {
                "precision": "int8",
                "model_path": os.path.join(model_folder, model_name + "int8"),
            },
        ]
    }

    make_report(
        args=args,
        benchmark_class=CTranslateBenchmark,
        runner_dict=runner_dict,
        benchmark_name="CTranslate2",
        is_bench_pytorch=False,
    )
```
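For orientation, here is a minimal, hypothetical way the new class could be driven by hand, bypassing the launch_cli/make_report harness. The model path, experiment name, and prompt are placeholders for illustration, not artifacts produced by this commit:

```python
# Illustrative sketch only: exercises CTranslateBenchmark directly.
# The path below follows the naming scheme the script builds, but the
# converted model must already exist there.
bench = CTranslateBenchmark(
    model_path="models/llama-2-7b-chat-ctranslate2-float16",  # placeholder path
    model_name="llama",
    benchmark_name="CTranslate2",
    precision="float16",
    device="cuda",
    experiment_name="ctranslate2-smoke-test",  # placeholder name
).load_model_and_tokenizer()

inputs = bench.preprocess("Summarize CTranslate2 in one sentence.", chat_mode=True)
result = bench.run_model(inputs, max_tokens=128)
print(bench.postprocess(result))
```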