diff --git a/deepspeed/runtime/data_pipeline/data_sampling/utils.py b/deepspeed/runtime/data_pipeline/data_sampling/utils.py index 9c643f3705de..dc55f96e222d 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/utils.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/utils.py @@ -3,7 +3,6 @@ # DeepSpeed Team -import math import numpy as np from deepspeed.utils import logger @@ -32,10 +31,8 @@ def find_fit_int_dtype(min_value, max_value): def split_index(start_idx, end_idx, num_partitions): - partition_size = math.ceil((end_idx - start_idx) / num_partitions) - partitions = [[start_idx + x * partition_size, - min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)] - return partitions + partition_boundaries = np.linspace(start_idx, end_idx, dtype=int, num=num_partitions + 1) + return [(partition_boundaries[i], partition_boundaries[i + 1]) for i in range(num_partitions)] def split_dataset(dataset, num_workers, worker_id, num_threads):