Add tests
wschin committed Oct 27, 2023
1 parent 71e242e commit f5463b9
Showing 1 changed file with 185 additions and 0 deletions.
185 changes: 185 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_distributed.py
@@ -685,6 +685,191 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self)
)


class TestDistributedExpand(unittest.TestCase):
    def _check_distributed_expand(
        self,
        shape: Tuple[int, ...],
        target_shape: Tuple[int, ...],
        input_device_meshs: np.ndarray,
        input_shard_specs: Tuple[str, ...],
        output_device_meshs: np.ndarray,
        output_shard_specs: Tuple[str, ...],
    ):
        assert all(len(mesh.shape) == 1 for mesh in input_device_meshs)
        assert all(len(mesh.shape) == 1 for mesh in output_device_meshs)
        assert len(input_device_meshs) == len(input_shard_specs)
        assert len(output_device_meshs) == len(output_shard_specs)
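        # Shard spec strings have one entry per tensor axis, following the convention
        # used throughout this file: "R" marks an axis replicated on every device of
        # the mesh, and "S[i]" marks an axis split across axis i of the device mesh.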

        input_device_mesh_shapes = []
        input_device_mesh_elements = []
        for device_mesh in input_device_meshs:
            device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh)
            input_device_mesh_shapes.append(device_mesh_shape)
            input_device_mesh_elements.append(device_mesh_element)

        output_device_mesh_shapes = []
        output_device_mesh_elements = []
        for device_mesh in output_device_meshs:
            device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh)
            output_device_mesh_shapes.append(device_mesh_shape)
            output_device_mesh_elements.append(device_mesh_element)
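        # translate_device_mesh_to_attrs (defined earlier in this file) turns each 1-D
        # mesh into the shape/element attribute pair consumed by the distributed op;
        # e.g. np.array([0, 1]) is expected to yield a mesh shape of [2] with elements [0, 1].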

        @onnxscript.script()
        def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64):
            return MICROSOFT_OPSET.DistributedExpand(
                data_tensor,
                shape_tensor,
                input_device_mesh_shapes=input_device_mesh_shapes,
                input_device_mesh_elements=input_device_mesh_elements,
                input_shard_specs=input_shard_specs,
                output_device_mesh_shapes=output_device_mesh_shapes,
                output_device_mesh_elements=output_device_mesh_elements,
                output_shard_specs=output_shard_specs,
            )
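        # DistributedExpand lives in the com.microsoft (MICROSOFT_OPSET) domain; the
        # attributes above describe how each input and output is laid out across ranks.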

        rank = comm.Get_rank()
        data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)
        shape_tensor = np.array(
            target_shape,
            dtype=np.int64,
        )

        local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0])
        assert "S" not in input_shard_specs[1], "Shape should not be sharded."

        expected = data_tensor * np.ones(shape_tensor)
        local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0])
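        # The reference result is computed on the full logical tensor: multiplying by
        # np.ones(shape_tensor) reproduces Expand's broadcasting, and local_expected is
        # the shard of that logical result this rank is supposed to produce.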

        onnx_model = distributed_expand_instance.to_model_proto(
            input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]],
            output_types=[FLOAT[tuple(local_expected.shape)]],
        )

        # Each MPI process owns a sharded model.
        sess = ort.InferenceSession(
            onnx_model.SerializeToString(),
            providers=["CUDAExecutionProvider"],
            provider_options=[{"device_id": str(rank)}],
        )

        # Each MPI process executes its sharded model.
        # The result is the `local` tensor stored on a specific MPI rank,
        # not the `logical` tensor.
        result = sess.run(
            None,
            {
                "data_tensor": local_data_tensor,
                "shape_tensor": shape_tensor,
            },
        )

        # Compare the local tensor with the corresponding logical sub-tensor
        # obtained by sharding the logical tensor according to the output's sharding spec.
        np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8)
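        # When a spec is all "R" (fully replicated), the local tensor is simply the
        # logical tensor, so this comparison degenerates to an ordinary allclose check
        # on every rank.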

    def test_expand_sharded_on_expanded_axis(self):
        # data: shape=[8,1], spec=(RR, [0,1])
        # shape: shape=[2], spec=(R, [0,1]), value=[8,4]
        # output: shape=[8,4], spec=(RS, [0,1])
        self._check_distributed_expand(
            shape=(
                8,
                1,
            ),
            target_shape=(
                8,
                4,
            ),
            input_device_meshs=[np.array([0, 1])] * 2,
            input_shard_specs=("RR", "R"),
            output_device_meshs=[np.array([0, 1])],
            output_shard_specs=("RS[0]",),
        )

    def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self):
        # data: shape=[8,1], spec=(RR, [0,1])
        # shape: shape=[2], spec=(R, [0,1]), value=[8,8]
        # output: shape=[8,8], spec=(RS, [0,1,0,1])
        self._check_distributed_expand(
            shape=(
                8,
                1,
            ),
            target_shape=(
                8,
                8,
            ),
            input_device_meshs=[np.array([0, 1])] * 2,
            input_shard_specs=("RR", "R"),
            output_device_meshs=[np.array([0, 1, 0, 1])],
            output_shard_specs=("RS[0]",),
        )

    def test_expand_replicated_on_expanded_axis(self):
        # data: shape=[8,1], spec=(RR, [0,1])
        # shape: shape=[2], spec=(R, [0,1]), value=[1,4]
        # output: shape=[8,4], spec=(RR, [0,1])
        self._check_distributed_expand(
            shape=(
                8,
                1,
            ),
            target_shape=(
                1,
                4,
            ),
            input_device_meshs=[np.array([0, 1])] * 2,
            input_shard_specs=("RR", "R"),
            output_device_meshs=[np.array([0, 1])],
            output_shard_specs=("RR",),
        )

    def test_expand_with_pass_through_sharding_spec(self):
        # data: shape=[8,1], spec=(SR, [0,1])
        # shape: shape=[2], spec=(R, [0,1]), value=[1,4]
        # output: shape=[8,4], spec=(SR, [0,1])
        self._check_distributed_expand(
            shape=(
                8,
                1,
            ),
            target_shape=(
                1,
                4,
            ),
            input_device_meshs=[np.array([0, 1])] * 2,
            input_shard_specs=(
                "S[0]R",
                "R",
            ),
            output_device_meshs=[np.array([0, 1])],
            output_shard_specs=("S[0]R",),
        )

    def test_expand_in_tiny_llama(self):
        # data: shape=[2,4,256,4], spec=(RSRR, [0,1])
        # shape: shape=[4], spec=(R, [0,1]), value=[2,4,256,4]
        # output: shape=[2,4,256,4], spec=(RSRR, [0,1])
        self._check_distributed_expand(
            shape=(
                2,
                4,
                256,
                4,
            ),
            target_shape=(
                2,
                4,
                256,
                4,
            ),
            input_device_meshs=[np.array([0, 1])] * 2,
            input_shard_specs=("RS[0]RR", "R"),
            output_device_meshs=[np.array([0, 1])],
            output_shard_specs=("RS[0]RR",),
        )
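# The device meshes above reference ranks 0 and 1 only, so these cases are presumably
# launched under mpirun with two workers, each pinned to its own CUDA device via device_id.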


class TestDistributed(unittest.TestCase):
    def test_matmul_rs_sr_rr(self):
        # It means 1-D tensor with single element: [2].
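As a quick aside on what the shard specs in these tests mean in practice, here is a minimal NumPy sketch of contiguous-block sharding along one axis. shard_along_axis is a hypothetical stand-in for the file's shard_tensor_per_spec helper, and the contiguous-block assignment of blocks to ranks is an assumption, not code from this commit.

import numpy as np

def shard_along_axis(logical: np.ndarray, rank: int, axis: int, device_mesh: np.ndarray) -> np.ndarray:
    # Split the logical tensor into len(device_mesh) equal blocks along `axis`
    # and return the block owned by this rank (contiguous-block sharding).
    blocks = np.split(logical, len(device_mesh), axis=axis)
    return blocks[list(device_mesh).index(rank)]

# An [8, 4] logical tensor with spec "RS[0]" over mesh [0, 1]: axis 0 stays
# replicated, axis 1 is split, so each rank holds an [8, 2] local tensor.
logical = np.arange(32, dtype=np.float32).reshape(8, 4)
local_rank0 = shard_along_axis(logical, rank=0, axis=1, device_mesh=np.array([0, 1]))
assert local_rank0.shape == (8, 2)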
