
Add HPU accelerator support in unit tests. #5162

25 changes: 20 additions & 5 deletions tests/unit/common.py
@@ -23,7 +23,7 @@
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
-DEEPSPEED_TEST_TIMEOUT = 600
+DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))


def is_rocm_pytorch():
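Note: the per-test worker timeout is no longer hard-coded; it still defaults to 600 seconds but can now be raised through the DS_UNITTEST_TIMEOUT environment variable (e.g. DS_UNITTEST_TIMEOUT=1200 pytest tests/unit). A minimal usage sketch of the same pattern, with the 1200-second value chosen purely for illustration:

import os

# Hypothetical override: raise the worker timeout to 20 minutes before
# tests/unit/common.py is imported; 600 s stays the default otherwise.
os.environ.setdefault('DS_UNITTEST_TIMEOUT', '1200')
DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))
print(DEEPSPEED_TEST_TIMEOUT)  # -> 1200 with the override, 600 without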
@@ -81,6 +81,11 @@ def set_accelerator_visible():
            match = re.search('Device Type.*GPU', line)
            if match:
                num_accelerators += 1
+    elif get_accelerator().device_name() == 'hpu':
+        hl_smi = subprocess.check_output(['hl-smi', "-L"])
+        num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
+        num_accelerators = sorted(num_accelerators, key=int)
+        os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
    elif get_accelerator().device_name() == 'npu':
        npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
        num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
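For reference, the new HPU branch derives the visible device list from hl-smi -L output instead of counting devices. A self-contained sketch of the same parsing, run against a fabricated two-card sample since the real command needs a Habana host:

import re

# Fabricated excerpt of `hl-smi -L` output (illustration only)
sample = b"Module ID         : 1\nModule ID         : 0\n"

module_ids = sorted(re.findall(r"Module ID\s+:\s+(\d+)", sample.decode()), key=int)
visible = ",".join(module_ids)   # "0,1", as written to HABANA_VISIBLE_MODULES above
print(module_ids, visible)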
@@ -90,7 +95,10 @@ def set_accelerator_visible():
            subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
        num_accelerators = cpu_sockets

-    cuda_visible = ",".join(map(str, range(num_accelerators)))
+    if isinstance(num_accelerators, list):
+        cuda_visible = ",".join(num_accelerators)
+    else:
+        cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
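The rotation code itself is collapsed in this view; a minimal sketch of the behavior the comment describes, assuming a simple slice-based rotation keyed on the xdist worker id (the real implementation may differ in detail):

def rotate_for_worker(visible, wid):
    # e.g. wid=0 -> ['0', '1', '2', '3'], wid=1 -> ['1', '2', '3', '0']
    wid = wid % len(visible)
    return visible[wid:] + visible[:wid]

assert rotate_for_worker(['0', '1', '2', '3'], 2) == ['2', '3', '0', '1']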
@@ -149,6 +157,12 @@ def _get_fixture_kwargs(self, request, func):
    def _launch_daemonic_procs(self, num_procs):
        # Create process pool or use cached one
        master_port = None
+
+        if get_accelerator().device_name() == 'hpu':
+            if self.reuse_dist_env:
+                print("Ignoring reuse_dist_env for hpu")
+                self.reuse_dist_env = False
+
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
@@ -169,9 +183,10 @@ def _launch_daemonic_procs(self, num_procs):
            # usually means an environment error and the rest of tests will
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hanged, exiting", returncode=1)
-
-        # Tear down distributed environment and close process pools
-        self._close_pool(pool, num_procs)
+        finally:
+            # Regardless of the outcome, ensure proper teardown
+            # Tear down distributed environment and close process pools
+            self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
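Worth noting on the last hunk: pytest.exit() raises an exception rather than returning, so the old teardown placed after the try/except could be skipped whenever a test hung; moving it into a finally block makes the cleanup unconditional. A reduced sketch of the pattern, with hypothetical run_async / close_pool helpers standing in for the real pool plumbing:

import multiprocessing as mp
import pytest

def launch(run_async, close_pool, timeout):
    try:
        skip_msgs = run_async(timeout)        # may raise mp.TimeoutError on a hang
    except mp.TimeoutError:
        pytest.exit("Test hanged, exiting", returncode=1)   # raises, never returns
    finally:
        close_pool()                          # now runs on success, skip, or exit
    return skip_msgs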