diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index f3056a225a9b..4e203a71db60 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -653,8 +653,15 @@ def no_pool_bootstrap_stderr(f, xs, iters): setattr(lm, model_family, getattr(lm, model_family).half().to(device)) lm._device = device else: - lm = lm_eval.models.get_model(model_family).create_from_arg_string( - f"pretrained={model_name}", {"device": get_accelerator().device_name()}) + if get_accelerator().device_name() == 'hpu': + #lm_eval not supporting HPU device, so get model with CPU and move it to HPU. + lm = lm_eval.models.get_model(model_family).create_from_arg_string(f"pretrained={model_name}", + {"device": "cpu"}) + setattr(lm, model_family, getattr(lm, model_family).to(device)) + lm._device = device + else: + lm = lm_eval.models.get_model(model_family).create_from_arg_string( + f"pretrained={model_name}", {"device": get_accelerator().device_name()}) get_accelerator().synchronize() start = time.time()