[Fix/Example] Fix Llama Inference Loading Data Type (#5763)
* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark llama3
Authored by yuanheng-zhao on May 30, 2024
Parent: 023ea13 · Commit: 677cbfa
Showing 2 changed files with 19 additions and 2 deletions.
examples/inference/llama/benchmark_llama3.py (11 additions, 1 deletion)
@@ -17,6 +17,13 @@
 MEGABYTE = 1024**2
 N_WARMUP_STEPS = 2
 
+TORCH_DTYPE_MAP = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+}
+
+
 CONFIG_MAP = {
     "toy": transformers.LlamaConfig(num_hidden_layers=4),
     "llama-7b": transformers.LlamaConfig(
@@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coor
 def benchmark_inference(args):
     coordinator = DistCoordinator()
 
+    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
     config = CONFIG_MAP[args.model]
+    config.torch_dtype = torch_dtype
     config.pad_token_id = config.eos_token_id
+
     if args.model_path is not None:
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     else:
         # Random weights
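The substantive change in this file: without an explicit torch_dtype, transformers loads checkpoint weights in float32 regardless of the benchmark's dtype argument. A minimal standalone sketch of the pattern, with hypothetical names (load_llama, model_path, dtype_str are illustrative, not part of the commit):

# Sketch of the dtype-resolution pattern introduced above; helper name is hypothetical.
import torch
import transformers

TORCH_DTYPE_MAP = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}

def load_llama(model_path: str, dtype_str: str):
    # Unknown dtype strings fall back to None, which leaves the precision
    # to transformers' default loading behavior (float32).
    torch_dtype = TORCH_DTYPE_MAP.get(dtype_str, None)
    return transformers.LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)

# Example: load_llama("/path/to/llama3-8b", "bf16").dtype == torch.bfloat16,
# roughly halving weight memory versus a float32 load.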
examples/inference/llama/llama_generation.py (8 additions, 1 deletion)
@@ -1,5 +1,6 @@
 import argparse
 
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 import colossalai
@@ -12,6 +13,12 @@
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
 
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
 
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@ def infer(args):
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
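Both scripts read args.dtype, so the change assumes a --dtype CLI flag whose values match the TORCH_DTYPE_MAP keys. A hedged sketch of that wiring (the actual parser definition sits outside this diff):

# Hypothetical argparse wiring for the --dtype flag assumed by the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="Model name or local checkpoint path.")
parser.add_argument(
    "--dtype",
    type=str,
    default="fp16",
    choices=["fp16", "fp32", "bf16"],  # keep in sync with TORCH_DTYPE_MAP
    help="Precision in which to load the model weights.",
)
args = parser.parse_args()

With choices= constraining the flag, the .get(args.dtype, None) fallback never fires; without it, a mistyped value would silently fall back to the default float32 load rather than raising an error.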
