Skip to content

Commit

Permalink
Remove jax.host_count alias.
Browse files Browse the repository at this point in the history
Resolves warnings about jax.host_count being renamed to jax.process_count.

PiperOrigin-RevId: 566401218
  • Loading branch information
texasmichelle authored and t5-copybara committed Sep 18, 2023
1 parent ea66ec8 commit 4d28257
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
8 changes: 4 additions & 4 deletions t5x/contrib/moe/checkpoints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,25 +286,25 @@ def get_mesh_axes(self, train_state):
'jax.experimental.multihost_utils.sync_global_devices', return_value=None
)
@mock.patch('time.time', return_value=0)
@mock.patch('jax.host_count')
@mock.patch('jax.process_count')
@mock.patch('jax.process_index')
def call_host_checkpointer(
self,
train_state,
process_index,
host_count,
process_count,
partitioner,
fn,
save_dtype,
ds_iter,
mock_process_index,
mock_host_count,
mock_process_count,
unused_mock_host_time,
unused_mock_sync_devices,
restore_dtype=np.float32,
):
mock_process_index.return_value = process_index
mock_host_count.return_value = host_count
mock_process_count.return_value = process_count

checkpointer = checkpoints.UpcycleCheckpointer(
train_state,
Expand Down
6 changes: 4 additions & 2 deletions t5x/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bounds_from_last_device(last_device: jax.Device) -> HardwareMesh:
else:
# On non-TPU platforms, the "mesh" is hosts x devices per host in order
# to take advantage of faster within-host interconnect.
return jax.host_count(), jax.local_device_count()
return jax.process_count(), jax.local_device_count()


def get_coords(device: jax.Device) -> HardwareMesh:
Expand Down Expand Up @@ -259,7 +259,9 @@ def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:

def get_cpu_mesh() -> Mesh:
"""Trivial mesh for CPU Testing."""
devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
devices = np.empty(
(jax.process_count(), jax.local_device_count()), dtype=object
)
for device in jax.devices():
devices[device.process_index, device.id % jax.local_device_count()] = device
return Mesh(devices, ['data', 'model'])
Expand Down
6 changes: 3 additions & 3 deletions t5x/partitioning_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@

class PartitioningTest(absltest.TestCase):

@mock.patch('jax.host_count')
@mock.patch('jax.process_count')
@mock.patch('jax.local_device_count')
def test_bounds_from_last_device(self, local_device_count, host_count):
def test_bounds_from_last_device(self, local_device_count, process_count):
last_device = mock.Mock(coords=(3, 3, 3), core_on_chip=1)
tpu_bounds = partitioning.bounds_from_last_device(last_device)
self.assertEqual(tpu_bounds, (4, 4, 4, 2))

last_device = mock.Mock(spec=[])
host_count.return_value = 1
process_count.return_value = 1
local_device_count.return_value = 4
non_tpu_bounds = partitioning.bounds_from_last_device(last_device)
self.assertEqual(non_tpu_bounds, (1, 4))
Expand Down

0 comments on commit 4d28257

Please sign in to comment.