diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
index a946e8812d3ff..4691c8b5935a5 100644
--- a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
@@ -26,7 +26,47 @@ DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedK
 template <typename T>
 Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
   ORT_ENFORCE(context != nullptr);
-  return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported expand pattern.");
+  // Assumptions.
+  // - Shape is not sharded.
+  // Algorithm.
+  // - Compute logical output shape.
+  // - Compute local output shape.
+  // - Expand from local input to local output.
+
+  auto input_tensor = context->Input<Tensor>(0);
+  auto shape_tensor = context->Input<Tensor>(1);
+  const auto& input_sharding_spec = input_shard_specs_.at(0);
+  const auto& shape_sharding_spec = input_shard_specs_.at(1);
+  const auto& output_sharding_spec = output_shard_specs_.at(0);
+
+  ORT_ENFORCE(shape_sharding_spec.HasNoShard(),
+              "It's not worthwhile to shard the Shape tensor. "
+              "If sharding shape is needed, please submit a feature request.");
+  // Compute logical input shape.
+  const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec);
+
+  // Compute logical output shape.
+  // This `shape_tensor` stores the logical output shape.
+  const auto* p_shape = shape_tensor->Data<int64_t>();
+  TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
+  TensorShape original_output_shape(original_output_dims);
+  ORT_ENFORCE(
+      onnxruntime::cuda::ComputeExpandOutputShape(
+          Node().Name(),
+          original_input_shape,
+          original_output_dims, original_output_shape).IsOK());
+
+  // Compute local output shape.
+  const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec);
+
+  auto output_tensor = context->Output(0, local_output_shape);
+
+  return FuncExpand(
+      this,
+      context,
+      input_tensor,
+      shape_tensor,
+      output_tensor);
 }
 
 ONNX_OPERATOR_TYPED_KERNEL_EX(
diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc
index 368c167f58641..0c09917b9f0e8 100644
--- a/onnxruntime/core/providers/cuda/tensor/expand.cc
+++ b/onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -95,7 +95,7 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
 
   TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor.Shape().Size()};
   TensorShape output_shape(output_dims);
-  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
+  ORT_RETURN_IF_ERROR(ComputeExpandOutputShape(Node().Name(), input_data_tensor.Shape(), output_dims, output_shape));
   auto& output_tensor = *ctx->Output(0, output_shape);
   if (0 == output_shape.Size()) {
     return Status::OK();
@@ -203,7 +203,7 @@ std::unique_ptr<Tensor> FuncExpand(
   TensorShape output_shape(output_dims);
 
   ORT_ENFORCE(
-      ComputeOutputShape(
+      ComputeExpandOutputShape(
          cuda_kernel->Node().Name(),
          input_data_tensor->Shape(),
          output_dims, output_shape).IsOK());
diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h
index a0b12790017f6..7c3d400dde9ce 100644
--- a/onnxruntime/core/providers/cuda/tensor/expand.h
+++ b/onnxruntime/core/providers/cuda/tensor/expand.h
@@ -14,7 +14,7 @@ class Expand final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
-Status ComputeOutputShape(
+Status ComputeExpandOutputShape(
     const std::string& node_name,
     const TensorShape& lhs_shape,
     const TensorShape& rhs_shape,
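Reviewer note, not part of the patch: the new ComputeInternal rests on two pieces of shape arithmetic, Expand's broadcasting rule for the logical output shape and the per-device shard shape derived from a sharding spec. Below is a minimal standalone sketch of both, assuming one sharded axis divided evenly across devices. ShardSpec, LogicalExpandShape, and LocalShardShape are hypothetical simplifications of the helpers used in the diff (ComputeOriginShape / ComputeExpandOutputShape / ComputeShardShape), not the actual ONNX Runtime declarations.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical one-axis sharding spec (stand-in for the real spec type).
struct ShardSpec {
  int64_t axis = -1;       // -1 means replicated (no shard).
  int64_t num_shards = 1;  // Device count along the sharded axis.
};

// Expand's broadcasting rule: right-align the two shapes; each output
// dimension is max(input_dim, requested_dim), where a 1 stretches.
std::vector<int64_t> LogicalExpandShape(std::vector<int64_t> in,
                                        std::vector<int64_t> req) {
  // Left-pad the shorter shape with 1s so both ranks match.
  while (in.size() < req.size()) in.insert(in.begin(), 1);
  while (req.size() < in.size()) req.insert(req.begin(), 1);
  std::vector<int64_t> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    assert(in[i] == req[i] || in[i] == 1 || req[i] == 1);
    out[i] = std::max(in[i], req[i]);
  }
  return out;
}

// Local shard shape: divide the sharded axis of the logical shape evenly
// across devices; all other axes are unchanged. A replicated tensor
// keeps its logical shape.
std::vector<int64_t> LocalShardShape(std::vector<int64_t> logical,
                                     const ShardSpec& spec) {
  if (spec.axis >= 0) {
    assert(logical[spec.axis] % spec.num_shards == 0);
    logical[spec.axis] /= spec.num_shards;
  }
  return logical;
}

int main() {
  // Logical input [1, 256] expanded to [8, 512, 256]; an output sharded
  // on axis 1 across 4 devices yields a local output of [8, 128, 256].
  auto logical = LogicalExpandShape({1, 256}, {8, 512, 256});
  auto local = LocalShardShape(logical, ShardSpec{1, 4});
  assert((local == std::vector<int64_t>{8, 128, 256}));
  return 0;
}

The sketch also makes the kernel's implicit precondition visible: the sharded dimension of the logical output must be divisible by the shard count, which is why the patch computes the logical shape first and only then derives the local shape passed to context->Output.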