From 0eb9330c3ba836911932444caca7fec0cbdad222 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 20:13:48 -0700 Subject: [PATCH] Implement details for d-expand Fix a function call --- .../cuda/collective/distributed_expand.cc | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc index a946e8812d3ff..ec1826d1eabd2 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -26,7 +26,47 @@ DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedK template Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(context != nullptr); - return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported expand pattern."); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. + + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape).IsOK()); + + // Compute local output shape. + const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); } ONNX_OPERATOR_TYPED_KERNEL_EX(