Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695
) This PR introduces - New data structure to represent kernel-level (aka node-level or op-level) tensor sharding informaiton. I consider it as the fundamentaion of ONNX distribtued inference. - Building blocks for distribtued kernels implementation especially stateless implementation for communication ops. - Implementation of DistributedMatMul and its tests. Code structure: - sharding.h/.cc: Function to shard and reshard tensors (calling into NCCL). - sharding_spec.h/.cc: Representation of how a tensor is sharded. - distributed_matmul.h/.cc: Implementation of tensor parallel MatMul. Inputs and outputs are sharded across devices. - onnxruntime_test_distributed.py: distributed operator tests. Example of specifying sharding information ```python @onnxscript.script() def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: # Run MatMul by sharding x along column axis and w along row axis on # 2 GPUs. return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, device_mesh_shape=[2], device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "S[0]R"], output_shard_specs=["RR"], ) onnx_model = matmul_rs_sr_rr.to_model_proto( input_types=[FLOAT[2, "s"], FLOAT["s", 2]], output_types=[FLOAT[2, 2]], ) ``` In this example, the device mesh can be visualized as 1-D tensor, `[0, 1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., the 0-axis of the device mesh). Similarly, the 1st axis of `tensor_w` is sharded across `[0, 1]` as well. C++ classes to represent tensor sharding (copied from sharding_spec.h): ```cpp class DeviceMesh { public: // [Device Mesh and Tensor Sharding for Tensor Parallel] // Device mesh is a tensor of device indices. // A tensor can then be partitioned along specific mesh axes. // // Assume we have 4 GPUs indexed by 0, 1, 2, and 3. // Let's consider some examples. // 1. 1D device mesh [0, 1, 2, 3]. In this case, // device_mesh_shape is [4] and device_mesh_elements // is [0, 1, 2, 3]. // If we want to shard a 2-D tensor along its axis 1, the // corresponding sharding spec is a string "RS[0]". // 2. 2D device mesh [[0, 1], [2, 3]]. In this case, // device_mesh_shape is [2, 2] and device_mesh_elements // is [0, 1, 2, 3]. // If we want to shard a 2-D tensor's // rows along mesh axis 1 and // columns along mesh axis 0, the // corresponding sharding spec is a string "S[1]S[0]". // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding // proccess. // - Start with a 2-D device mesh [[0, 1], [2, 3]] and // a 2-D tensor [[5, 6], [7, 8]] // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] // - Split GPU mesh along axis 1 and tensor along // axis 0 for "S[1]" in "S[1]S[0]" // - GPU: [[0], [2]], Tensor: [[5, 6]] // GPU: [[1], [3]], Tensor: [[7, 8]] // - Split GPU mesh along axis 0 and tensor along // axis 1 for "S[0]" in "S[1]S[0]" // - GPU: [[0]], Tensor: [[5]] // - GPU: [[2]], Tensor: [[6]] // - GPU: [[1]], Tensor: [[7]] // - GPU: [[3]], Tensor: [[8]] // Actual shape of device mesh represented by `device_mesh_elements`. std::vector<int64_t> device_mesh_shape; // Flattened device mesh. std::vector<int64_t> device_mesh_elements; }; class AxisPartitionSpec { // [Device Mesh and Tensor Sharding for Tensor Parallel] // This class is the in-memory representation of // 1. if a tensor is sharded or not (aka replica), and // 2. which tensor axis is shard by which device mesh axis. // Let's consider sharding 2-D tensor along column axis on // device mesh [0, 1] as an example. // The required sharding spec RS[0] can be represented by // - AxisPartitionSpec(Condition::Replica, -1) // - AxisPartitionSpec(Condition::Shard, 0) public: // Status of a tensor axis. // A tensor axis can be either sharded or replicated // along a device mesh axis. enum class Condition { Replica, Shard }; // This field tells if a tensor axis is sharded or not. Condition cond; // If a tensor axis is sharded, this field tells which device // mesh axis to distribute the shards along. // If a tensor axis is not sharded, this field is ignored. int device_mesh_axis; // A helper to construct a replica spec for a tensor axis. static AxisPartitionSpec CreateReplica() { return AxisPartitionSpec(Condition::Replica, -1); } // A helper to construct a sharding spec for a tensor axis. // This tensor axis is sharded along `device_mesh_axis` in device mesh. static AxisPartitionSpec CreateShard(int device_mesh_axis) { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } }; class TensorPartitionSpec { // [Device Mesh and Tensor Sharding for Tensor Parallel] // TensorPartitionSpec holds a collection of AxisPartitionSpec and an // associated DeviceMesh. It is responsible for determining how a tensor // should be partitioned across a device mesh. // // Example 1: RS[0] // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects. // - The first object is a Replica, denoting that the first axis of the tensor is // not sharded but is instead replicated. // - The second object is a Shard along the 0-th axis of the device mesh. It denotes // that the second axis of the tensor is sharded along the first axis of the // device mesh. // // Example 2: S[0]RR // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects. // - The first object is a Shard along the 0-th axis of the device mesh, indicating // that the first axis of the tensor is sharded along the first axis of the // device mesh. // - The second and third objects are Replicas, indicating that the second and third // axes of the tensor are not sharded but are instead replicated. public: // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor, // axis_specs[0] is for row axis and axis_specs[1] is for // column axis. axis_specs[i].device_mesh_axis = j means that // tensor axis i is sharded along device mesh axis j. std::vector<AxisPartitionSpec> axis_specs; // device_mesh: DeviceMesh for sharding the associated tensor. // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment. DeviceMesh device_mesh; }; ```
- Loading branch information