update eval and fix example
Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Sep 13, 2024
1 parent 3275df9 commit c920c7a
Showing 5 changed files with 148 additions and 513 deletions.
6 changes: 1 addition & 5 deletions auto_round/__main__.py
@@ -348,19 +348,15 @@ def get_library_version(library_name):


def eval(args):
quantization_config = AutoRoundConfig(backend=args.device)
device_str = detect_device(args.device)
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map=device_str, quantization_config=quantization_config)
model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
if isinstance(args.tasks, str):
tasks = args.tasks.split(',')
res = simple_evaluate(
model="hf",
model_args=model_args,
user_model=user_model,
tasks=tasks,
device=device_str,
batch_size=args.eval_bs)

from lm_eval.utils import make_table # pylint: disable=E0401
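For context, the updated `eval` entry point hands model construction to the harness (via a `model_args` string) instead of building the quantized model explicitly. A minimal sketch of the resulting flow, with placeholder model path, task list, and batch size:

```python
# Sketch of the post-change eval flow; the path, task, and batch size are illustrative.
from auto_round.eval.evaluation import simple_evaluate
from lm_eval.utils import make_table

model_args = "pretrained=/path/to/model,trust_remote_code=True"
res = simple_evaluate(
    model="hf",               # let lm-eval build the HF model from model_args
    model_args=model_args,
    tasks=["lambada_openai"],
    device="cuda:0",
    batch_size=16,
)
print(make_table(res))
```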
285 changes: 35 additions & 250 deletions auto_round/eval/evaluation.py
@@ -17,269 +17,54 @@
import time
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch
import lm_eval
from lm_eval import simple_evaluate as lm_simple_evaluate

import lm_eval.api.metrics # pylint: disable=E0401
import lm_eval.api.registry # pylint: disable=E0401
import lm_eval.models # pylint: disable=E0401
from lm_eval.caching.cache import delete_cache # pylint: disable=E0401
from lm_eval.evaluator import evaluate # pylint: disable=E0401
from lm_eval.evaluator_utils import run_task_tests # pylint: disable=E0401
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash # pylint: disable=E0401
from lm_eval.tasks import TaskManager, get_task_dict # pylint: disable=E0401
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string # pylint: disable=E0401

if TYPE_CHECKING:
from lm_eval.api.model import LM
from lm_eval.tasks import Task


@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
user_model = None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
use_cache: Optional[str] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None,
verbosity: str = "INFO",
predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
user_model = None, ##user model does not support tensor parallelism
):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
Name of model or LM object, see lm_eval.models.get_model
:param model_args: Optional[str, dict]
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
Ignored if `model` argument is a LM object.
:param tasks: list[Union[str, dict, Task]]
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME
if defined and type(task).__name__ otherwise.
:param num_fewshot: int
Number of examples in few-shot context
:param batch_size: int or str, optional
Batch size for model
:param max_batch_size: int, optional
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param cache_requests: bool, optional
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
:param rewrite_requests_cache: bool, optional
Rewrites all of the request cache if set to `True`. `None` if not desired.
:param delete_requests_cache: bool, optional
Deletes all of the request cache if set to `True`. `None` if not desired.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1,
limit is a percentage of the total number of examples.
:param bootstrap_iters:
Number of iterations for bootstrap statistics
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:param write_out: bool
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param gen_kwargs: str
String arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param predict_only: bool
If true only model outputs will be generated and returned. Metrics will not be evaluated
:param random_seed: int
Random seed for python's random module. If set to None, the seed will not be set.
:param numpy_random_seed: int
Random seed for numpy. If set to None, the seed will not be set.
:param torch_random_seed: int
Random seed for torch. If set to None, the seed will not be set.
:return
Dictionary of results
"""
from auto_round.auto_quantizer import AutoHfQuantizer
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
start_date = time.time()

if delete_requests_cache:
eval_logger.info("Deleting requests cache...")
delete_cache()

seed_message = []
if random_seed is not None:
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
seed_message.append(f"Setting random seed to {random_seed}")
random.seed(random_seed)

if numpy_random_seed is not None:
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
np.random.seed(numpy_random_seed)

if torch_random_seed is not None:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)

if seed_message:
eval_logger.info(" | ".join(seed_message))

if tasks is None:
tasks = []
if len(tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)

if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
**kwargs):

try:
from auto_round import AutoRoundConfig
except:
from auto_round.auto_quantizer import AutoHfQuantizer

if model_args is None:
model_args = ""

if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
if gen_kwargs == "":
gen_kwargs = None

if isinstance(model, str):
if model_args is None:
model_args = ""

if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)

else:
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
else:
if not isinstance(model, lm_eval.api.model.LM):
raise TypeError
lm = model

if use_cache is not None:
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank"
+ str(lm.rank)
+ ".db",
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
},
)
if user_model is not None:
lm._model = user_model
return lm_simple_evaluate(
model=lm,
model_args=model_args,
batch_size=batch_size,
max_batch_size=max_batch_size,
device=device,
**kwargs)

if task_manager is None:
task_manager = TaskManager(verbosity)

task_dict = get_task_dict(tasks, task_manager)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if isinstance(task_obj, tuple):
_, task_obj = task_obj
if task_obj is None:
continue

if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
)

if predict_only:
log_samples = True
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
# we have to change the class properties post-hoc. This is pretty hacky.
task_obj.override_metric(metric_name="bypass")

# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config."
" Manual configuration will be ignored."
)
else:
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
task_obj.set_config(key="num_fewshot", value=0)

if check_integrity:
run_task_tests(task_list=tasks)

results = evaluate(
lm=lm,
task_dict=task_dict,
limit=limit,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
write_out=write_out,
log_samples=log_samples,
verbosity=verbosity,
)

if lm.rank == 0:
if isinstance(model, str):
model_name = model
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
model_name = model.config._name_or_path
else:
model_name = type(model).__name__

# add info about the model and few shot config
results["config"] = {
"model": model_name,
"model_args": model_args,
"batch_size": batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": device,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
}
results["git_hash"] = get_git_commit_hash()
results["date"] = start_date
add_env_info(results) # additional environment info to results
return results
else:
return None
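The new wrapper keeps a `user_model` hook: when a model object is passed, it is swapped into the harness LM via `lm._model` before delegating to `lm_simple_evaluate`. A rough sketch of that path, with placeholder paths and task list (and, per the inline comment in the diff, the user model does not support tensor parallelism):

```python
# Sketch: evaluating an already-loaded model through the user_model hook.
import torch
from transformers import AutoModelForCausalLM
from auto_round.eval.evaluation import simple_evaluate

user_model = AutoModelForCausalLM.from_pretrained(
    "/path/to/quantized_model",      # placeholder path
    device_map="cuda:0",
    torch_dtype=torch.float16,
)

res = simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/quantized_model",  # harness still builds the tokenizer from this
    user_model=user_model,           # replaces lm._model before evaluation
    tasks=["mmlu"],
    device="cuda:0",
    batch_size=8,
)
```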
7 changes: 4 additions & 3 deletions examples/language-modeling/README.md
@@ -63,9 +63,10 @@ reduce the seqlen to 512 (potential accuracy drop)
or combine them

- **Reduced CPU Memory Usage:**
set "--low_cpu_mem_mode 1" to use block-wise mode, load the weights from disk of each block when tuning and release the memory of the block after tuning. (more tuning cost)

set "--low_cpu_mem_mode 2" to use layer-wise mode, load the weights of each layer from disk when tuning, minimum memory consumption and also slowest running speed.
- set "--low_cpu_mem_mode 1" to use block-wise mode, load the weights from disk of each block when tuning and release the memory of the block after tuning. (more tuning cost)

- set "--low_cpu_mem_mode 2" to use layer-wise mode, load the weights of each layer from disk when tuning, minimum memory consumption and also slowest running speed.


- **Speedup the tuning:**
@@ -132,7 +133,7 @@ CUDA_VISIBLE_DEVICES=1,2 python3 eval/evaluation.py --model_name /save_model_pat
- Users can also run the evaluation on Intel Gaudi-2 using the following script.
```bash
python3 eval_042/evaluation.py --model_name /saved_model_path/ --tasks mmlu --device hpu
python3 eval/evaluation.py --model_name /saved_model_path/ --tasks mmlu --device hpu
```
You can also utilize the official lm_eval [link](https://github.com/EleutherAI/lm-evaluation-harness/tree/main?tab=readme-ov-file#basic-usage).
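For reference, a direct lm-evaluation-harness invocation might look like the following; the model path is a placeholder and the flags follow the upstream CLI:

```bash
lm_eval --model hf \
    --model_args pretrained=/saved_model_path/ \
    --tasks mmlu \
    --device cuda:0 \
    --batch_size 16
```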