diff --git a/tests/unit/common.py b/tests/unit/common.py
index c002aa372c8c..1fd83de81f02 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -23,7 +23,7 @@ from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
 
 # Worker timeout for tests that hang
-DEEPSPEED_TEST_TIMEOUT = 600
+DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))
 
 
 def is_rocm_pytorch():
@@ -81,6 +81,11 @@ def set_accelerator_visible():
                 match = re.search('Device Type.*GPU', line)
                 if match:
                     num_accelerators += 1
+        elif get_accelerator().device_name() == 'hpu':
+            hl_smi = subprocess.check_output(['hl-smi', "-L"])
+            num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
+            num_accelerators = sorted(num_accelerators, key=int)
+            os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
         elif get_accelerator().device_name() == 'npu':
             npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
             num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
@@ -90,7 +95,10 @@ def set_accelerator_visible():
                 subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
             num_accelerators = cpu_sockets
 
-        cuda_visible = ",".join(map(str, range(num_accelerators)))
+        if isinstance(num_accelerators, list):
+            cuda_visible = ",".join(num_accelerators)
+        else:
+            cuda_visible = ",".join(map(str, range(num_accelerators)))
 
         # rotate list based on xdist worker id, example below
         # wid=0 -> ['0', '1', '2', '3']
@@ -149,6 +157,12 @@ def _get_fixture_kwargs(self, request, func):
     def _launch_daemonic_procs(self, num_procs):
         # Create process pool or use cached one
         master_port = None
+
+        if get_accelerator().device_name() == 'hpu':
+            if self.reuse_dist_env:
+                print("Ignoring reuse_dist_env for hpu")
+                self.reuse_dist_env = False
+
         if self.reuse_dist_env:
             if num_procs not in self._pool_cache:
                 self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
@@ -169,9 +183,10 @@ def _launch_daemonic_procs(self, num_procs):
                 # usually means an environment error and the rest of tests will
                 # hang (causing super long unit test runtimes)
                 pytest.exit("Test hanged, exiting", returncode=1)
-
-        # Tear down distributed environment and close process pools
-        self._close_pool(pool, num_procs)
+        finally:
+            # Regardless of the outcome, ensure proper teardown
+            # Tear down distributed environment and close process pools
+            self._close_pool(pool, num_procs)
 
         # If we skipped a test, propagate that to this process
         if any(skip_msgs):
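
As a reviewer aid, here is a minimal sketch of what the new `hpu` branch in `set_accelerator_visible()` computes. The regex, numeric sort, and join come from the patch; the sample `hl-smi -L` output below is an assumption for illustration only, not captured from real hardware.

```python
# Illustrative sketch only (not part of the patch): mirrors the new 'hpu'
# branch of set_accelerator_visible(). The hl-smi output is assumed.
import re

sample_hl_smi_output = """\
Module ID        : 2
Module ID        : 0
Module ID        : 1
"""

# Extract the Module IDs as strings, sort them numerically, and join them
# into the comma-separated list exported as HABANA_VISIBLE_MODULES.
module_ids = re.findall(r"Module ID\s+:\s+(\d+)", sample_hl_smi_output)
module_ids = sorted(module_ids, key=int)   # ['0', '1', '2']
print(",".join(module_ids))                # "0,1,2"
```

The timeout change works the same way on any accelerator: the 600-second default still applies, but a run can now override it, e.g. `DS_UNITTEST_TIMEOUT=1200 pytest tests/unit/...`.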