diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 82cc17dd30b2e..003012f8da071 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -37,6 +37,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc index 253a58bd82a20..9008edbf3db30 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc @@ -4,7 +4,6 @@ // Distributed computation. #include "sharding.h" #include "distributed_matmul.h" -#include "nccl_kernels.h" #include "mpi_include.h" // ORT system. @@ -63,20 +62,7 @@ static TensorShape InferMatmulOutputShape( }; template -DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : NcclKernel(info) { - std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); - std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); - std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); - std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); - - for (size_t i = 0; i < input_shard_specs.size(); ++i) { - auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); - input_shard_specs_.push_back(spec); - } - for (size_t i = 0; i < output_shard_specs.size(); ++i) { - auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); - output_shard_specs_.push_back(spec); - } +DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : DistributedKernel(info) { } template diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h index d8df24c03498f..da07f9a8b2c7b 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - -#include "sharding_spec.h" -#include "core/providers/cuda/cuda_kernel.h" +#include "sharding.h" #include #include @@ -20,15 +18,11 @@ namespace cuda { #if defined(ORT_USE_NCCL) template -class DistributedMatMul final : public NcclKernel { +class DistributedMatMul final : public DistributedKernel { public: explicit DistributedMatMul(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - - private: - std::vector input_shard_specs_; - std::vector output_shard_specs_; }; #endif diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc new file mode 100644 index 0000000000000..5768dba791292 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_slice.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) +template +DistributedSlice::DistributedSlice(const OpKernelInfo& info) : DistributedKernel(info) { +} + +template +Status DistributedSlice::ComputeInternal(OpKernelContext* context) const { + const auto tensor_shard_data = context->Input(0); + const auto tensor_shard_starts = context->Input(1); + const auto tensor_shard_ends = context->Input(2); + + const TensorPartitionSpec& spec_data = input_shard_specs_[0]; + const TensorPartitionSpec& spec_starts = input_shard_specs_[1]; + const TensorPartitionSpec& spec_ends = input_shard_specs_[2]; + const TensorPartitionSpec& spec_Y = output_shard_specs_[0]; + + const auto tensor_shard_axes = context->Input(3); + const TensorPartitionSpec& spec_axes = input_shard_specs_[3]; + + if (spec_starts.HasShard() || + spec_ends.HasShard() || + spec_axes.HasShard() || + (input_shard_specs_.size() > 4 && input_shard_specs_[4].HasShard())) + ORT_THROW("DistributedSlice: shard on starts / ends / axes / steps are not supported yet."); + + std::vector input_starts; + std::vector input_ends; + auto starts_data = tensor_shard_starts->DataAsSpan(); + input_starts.resize(starts_data.size()); + std::copy(starts_data.begin(), starts_data.end(), input_starts.begin()); + auto ends_data = tensor_shard_ends->DataAsSpan(); + input_ends.resize(ends_data.size()); + std::copy(ends_data.begin(), ends_data.end(), input_ends.begin()); + + std::vector input_axes; + if (tensor_shard_axes) { + auto axes_data = tensor_shard_axes->DataAsSpan(); + input_axes.resize(axes_data.size()); + std::copy(axes_data.begin(), axes_data.end(), input_axes.begin()); + } + + std::vector input_steps; + const auto tensor_shard_steps = context->Input(4); + if (tensor_shard_steps) { + const TensorPartitionSpec& spec_steps = input_shard_specs_[4]; + if (spec_steps.HasShard()) + ORT_THROW("Not supported yet."); + + auto steps_data = tensor_shard_steps->DataAsSpan(); + input_steps.resize(steps_data.size()); + std::copy(steps_data.begin(), steps_data.end(), input_steps.begin()); + } + + if (spec_data.GetPartitionAxis() != -1 && + std::find(input_axes.begin(), input_axes.end(), spec_data.GetPartitionAxis()) != input_axes.end()) { + // shard on slice axes, reshard first + auto tmp_spec_data = TensorPartitionSpec::CreateAllReplica(spec_data); + auto tensor_data = ReshardTensor(this, context, spec_data, tmp_spec_data, nccl_->Rank(), tensor_shard_data); + + const auto& input_shape = tensor_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.HasNoShard()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_data.get(), + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + auto tmp_spec_output = TensorPartitionSpec::CreateAllReplica(spec_Y); + ReshardTensor(this, context, tmp_spec_output, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } else { + const auto& input_shape = tensor_shard_data->Shape(); + const auto input_dimensions = input_shape.GetDims(); + if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars"); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata)); + TensorShape output_shape(compute_metadata.output_dims_); + + if (spec_Y.GetPartitionAxis() == spec_data.GetPartitionAxis()) { + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + context->Output(0, output_shape))); + } else { + AllocatorPtr alloc; + ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst_tensor = Tensor::Create(tensor_shard_data->DataType(), output_shape, alloc); + ORT_RETURN_IF_ERROR(FuncSlice(this, + context, + tensor_shard_data, + input_starts, + input_ends, + input_axes, + input_steps, + dst_tensor.get())); + ReshardTensor(this, context, spec_data, spec_Y, nccl_->Rank(), dst_tensor.get(), 0); + } + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedSlice, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 1) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .InputMemoryType(OrtMemTypeCPUInput, 3) + .InputMemoryType(OrtMemTypeCPUInput, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + DistributedSlice); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h new file mode 100644 index 0000000000000..48c77eee241de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_slice.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "sharding.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedSlice final : public DistributedKernel { + public: + explicit DistributedSlice(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index d9f2f3c1bcbca..7d106fd75e2d0 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -212,6 +212,46 @@ std::unique_ptr ReshardTensor( return dst; } +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx) { + // Implement ReshardTensor but returning a unique_ptr to Tensor instead. + const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec); + const auto dst_shape = ComputeShardShape(origin_shape, dst_spec); + ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString()); + + auto* dst = ctx->Output(output_idx, dst_shape); + ReshardTensor( + nccl_kernel, + ctx, + src_spec, + dst_spec, + device_id, + src, + dst); +} + +DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) { + std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); + std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); + std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); + std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); + + for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); + input_shard_specs_.push_back(spec); + } + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); + output_shard_specs_.push_back(spec); + } +} + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.h b/onnxruntime/contrib_ops/cuda/collective/sharding.h index 497826160aaab..81a0f72f0c32f 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.h @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once #include "sharding_spec.h" #include "nccl_kernels.h" -#pragma once - namespace onnxruntime { namespace contrib { namespace cuda { @@ -49,6 +48,16 @@ void ReshardTensor( const Tensor* src, Tensor* dst); +// Output from ctx +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + int output_idx); + std::unique_ptr ReshardTensor( const NcclKernel* nccl_kernel, OpKernelContext* ctx, @@ -57,6 +66,17 @@ std::unique_ptr ReshardTensor( const int64_t device_id, const Tensor* src); +class TensorPartitionSpec; + +class DistributedKernel : public NcclKernel { + public: + explicit DistributedKernel(const OpKernelInfo& info); + + protected: + std::vector input_shard_specs_; + std::vector output_shard_specs_; +}; + #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 13982ee7711cf..0f5ef6927a545 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once #include "core/common/common.h" #include "core/framework/tensor_shape.h" @@ -8,8 +9,6 @@ #include #include -#pragma once - namespace onnxruntime { namespace contrib { namespace cuda { diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 71ee5ae1ddbe6..3e440a091870a 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -153,6 +153,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllT class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); #endif template <> @@ -310,6 +313,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 84eed7fae6ac1..7cdd71014c02e 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -105,6 +105,74 @@ void RegisterCollectiveOps() { "tensor(float)", }, "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedSlice) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("device_mesh_elements", + "", + AttributeProto::INTS) + .Attr("device_mesh_shape", + "", + AttributeProto::INTS) + .Attr("input_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Input( + 0, + "data", + "Tensor of data to extract slices from.", + "T", + OpSchema::Single, + true, + 1, + OpSchema::Differentiable) + .Input( + 1, + "starts", + "1-D tensor of starting indices of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 2, + "ends", + "1-D tensor of ending indices (exclusive) of corresponding axis in `axes`", + "Tind", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 3, + "axes", + "1-D tensor of axes that `starts` and `ends` apply to. Negative value means counting dimensions " + "from the back. Accepted range is [-r, r-1] where r = rank(data). Behavior is undefined if an " + "axis is repeated.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Input( + 4, + "steps", + "1-D tensor of slice step of corresponding axis in `axes`. " + "Negative value means slicing backward. 'steps' cannot be 0. " + "Defaults to 1s.", + "Tind", + OpSchema::Optional, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") + .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); } } // namespace contrib diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index 7f3cbc254969e..1baec80cb7c45 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -6,7 +6,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT +from onnxscript import FLOAT, INT64 import onnxruntime as ort @@ -18,7 +18,7 @@ def shard_tensor(X, rank, axis, num_shards): return np.split(X, num_shards, axis)[rank] -class TestDistributedMatMul(unittest.TestCase): +class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): @onnxscript.script() def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: @@ -312,6 +312,99 @@ def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: expected = np.matmul(tensor_x, tensor_w) np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + def test_slice_sr_axis1(self): + @onnxscript.script() + def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: + return MICROSOFT_OPSET.DistributedSlice( + tensor_x, + tensor_starts, + tensor_ends, + tensor_axes, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["S[0]R", "R", "R", "R", "R"], + output_shard_specs=["S[0]R"], + ) + + rank = comm.Get_rank() + # Shape [2, 4] + tensor_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32) + tensor_starts = np.array([0], dtype=np.int64) + tensor_ends = np.array([2], dtype=np.int64) + tensor_axes = np.array([1], dtype=np.int64) + + onnx_model = slice_sr_axis1.to_model_proto( + input_types=[FLOAT[1, 4], INT64[1], INT64[1], INT64[1]], + output_types=[FLOAT[1, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=0, num_shards=2) + + result = sess.run( + None, + { + "tensor_x": tensor_shard_x, + "tensor_starts": tensor_starts, + "tensor_ends": tensor_ends, + "tensor_axes": tensor_axes, + }, + ) + + expected = tensor_shard_x[:, 0:2] + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_slice_rs_axis1(self): + @onnxscript.script() + def slice_sr_axis1(tensor_x: FLOAT, tensor_starts: INT64, tensor_ends: INT64, tensor_axes: INT64) -> FLOAT: + return MICROSOFT_OPSET.DistributedSlice( + tensor_x, + tensor_starts, + tensor_ends, + tensor_axes, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RS[0]", "R", "R", "R", "R"], + output_shard_specs=["RS[0]"], + ) + + rank = comm.Get_rank() + # Shape [2, 4] + tensor_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32) + tensor_starts = np.array([0], dtype=np.int64) + tensor_ends = np.array([2], dtype=np.int64) + tensor_axes = np.array([1], dtype=np.int64) + + onnx_model = slice_sr_axis1.to_model_proto( + input_types=[FLOAT[2, 2], INT64[1], INT64[1], INT64[1]], + output_types=[FLOAT[2, 1]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2) + result = sess.run( + None, + { + "tensor_x": tensor_shard_x, + "tensor_starts": tensor_starts, + "tensor_ends": tensor_ends, + "tensor_axes": tensor_axes, + }, + ) + + expected = tensor_x[:, 0:2][:, rank : rank + 1] + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + if __name__ == "__main__": unittest.main()