From fce979138ebde19e9b30ebc5f209928695148f16 Mon Sep 17 00:00:00 2001 From: vikram singh shekhawat Date: Wed, 13 Mar 2024 21:44:15 +0530 Subject: [PATCH] Enabled LMCorrectness inference tests on HPU. (#5271) The lm_eval API (v0.3.0) does not currently support the HPU accelerator, so to run the LMCorrectness tests on HPU, the lm_eval model is created on CPU and then moved to the HPU accelerator. --- tests/unit/inference/test_inference.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index f3056a225a9b..4e203a71db60 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -653,8 +653,15 @@ def no_pool_bootstrap_stderr(f, xs, iters): setattr(lm, model_family, getattr(lm, model_family).half().to(device)) lm._device = device else: - lm = lm_eval.models.get_model(model_family).create_from_arg_string( - f"pretrained={model_name}", {"device": get_accelerator().device_name()}) + if get_accelerator().device_name() == 'hpu': + #lm_eval not supporting HPU device, so get model with CPU and move it to HPU. + lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}", + {"device": "cpu"}) + setattr(lm, model_family, getattr(lm, model_family).to(device)) + lm._device = device + else: + lm = lm_eval.models.get_model(model_family).create_from_arg_string( + f"pretrained={model_name}", {"device": get_accelerator().device_name()}) get_accelerator().synchronize() start = time.time()