From da61ecf678017a19ab15e2f45b6e0bad6997b14b Mon Sep 17 00:00:00 2001 From: Artur Fierka Date: Mon, 16 Dec 2024 16:52:31 +0100 Subject: [PATCH] Unit scales in FP8 CI scenarios (#633) --- .../configs/Meta-Llama-3.1-8B-Instruct-fp8.yaml | 6 +++--- .../lm-eval-harness/inc_unit_scales_config.json | 16 ++++++++++++++++ .../lm-eval-harness/test_lm_eval_correctness.py | 8 +++----- 3 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 .jenkins/lm-eval-harness/inc_unit_scales_config.json diff --git a/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-fp8.yaml b/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-fp8.yaml index 80a8c522bc5a0..5c1cd657e8e36 100644 --- a/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-fp8.yaml +++ b/.jenkins/lm-eval-harness/configs/Meta-Llama-3.1-8B-Instruct-fp8.yaml @@ -5,10 +5,10 @@ tasks: - name: "gsm8k_cot_llama" metrics: - name: "exact_match,strict-match" - value: 0.8317 + value: 0.664 - name: "exact_match,flexible-extract" - value: 0.8355 -limit: null + value: 0.676 +limit: 250 num_fewshot: 8 dtype: "bfloat16" fewshot_as_multiturn: true diff --git a/.jenkins/lm-eval-harness/inc_unit_scales_config.json b/.jenkins/lm-eval-harness/inc_unit_scales_config.json new file mode 100644 index 0000000000000..cd6589c811417 --- /dev/null +++ b/.jenkins/lm-eval-harness/inc_unit_scales_config.json @@ -0,0 +1,16 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "unit_scale", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "dump_stats_path": "" +} \ No newline at end of file diff --git a/.jenkins/lm-eval-harness/test_lm_eval_correctness.py b/.jenkins/lm-eval-harness/test_lm_eval_correctness.py index 9272123034350..55d633e51ce97 100644 --- a/.jenkins/lm-eval-harness/test_lm_eval_correctness.py +++ b/.jenkins/lm-eval-harness/test_lm_eval_correctness.py @@ -27,12 +27,10 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) -def setup_fp8(model_path, device_type): - flavor = f"g{device_type[-1]}" - normalized_model_name = Path(model_path).parts[-1].lower() +def setup_fp8(): os.environ[ "QUANT_CONFIG"] = \ - f"/software/data/vllm-benchmarks/inc/{normalized_model_name}/maxabs_quant_{flavor}.json" + "inc_unit_scales_config.json" def fail_on_exit(): @@ -147,7 +145,7 @@ def test_lm_eval_correctness(record_xml_attribute, record_property): # Set up environment for FP8 inference if eval_config.get("fp8"): - setup_fp8(eval_config["model_name"], platform) + setup_fp8() # Launch eval requests. start_time = time.perf_counter() results = launch_lm_eval(eval_config)