Parallel map step for DistributedDataAnalyzer map-reduce (#5291)
- adds multi CPU-process parallelism to the `DistributedDataAnalyzer` map
operation (degree of parallelism set with the parameter `num_threads`). Uses a
`multiprocessing.Manager` queue per metric, written to concurrently by the
spawned processes; see the sketch below.
- much faster `write_buffer_to_file` in the `DistributedDataAnalyzer` reduce
operation, by copying the output tensor to CPU and detaching it before writing.
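
A minimal sketch of the queue-per-metric pattern (names here are illustrative,
not the actual DeepSpeed API; the real logic is in run_map_helper /
run_map_reduce below):

from multiprocessing import Manager, Process

def map_worker(thread_id, shard, metric_queues):
    # each process computes its metric values locally, then publishes
    # (thread_id, results) so the parent can reassemble them in order
    results = [sum(shard), max(shard)]  # stand-in metric functions
    for queue, result in zip(metric_queues, results):
        queue.put((thread_id, result))

if __name__ == "__main__":
    shards = [[1, 2], [3, 4], [5, 6]]  # one data shard per spawned process
    with Manager() as manager:
        queues = [manager.Queue() for _ in range(2)]  # one queue per metric
        procs = [Process(target=map_worker, args=(t, s, queues)) for t, s in enumerate(shards)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        for m_idx, queue in enumerate(queues):
            while not queue.empty():  # drain; thread_id restores ordering
                t_idx, res = queue.get()
                print(f"metric {m_idx}: thread {t_idx} -> {res}")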

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Conglong Li <[email protected]>
3 people authored Apr 18, 2024
1 parent aaaf8bc commit 64defe6
Showing 2 changed files with 93 additions and 41 deletions.
128 changes: 90 additions & 38 deletions deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py
@@ -4,6 +4,7 @@
# DeepSpeed Team

import os
import sys
from collections import defaultdict
import csv
import time
@@ -12,8 +13,8 @@
import torch
from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset

from deepspeed.utils import logger
import deepspeed.comm as dist
from deepspeed.utils import logger
from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset, valid_dtypes
from deepspeed.runtime.data_pipeline.data_sampling.utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype

@@ -457,6 +458,7 @@ def __init__(
self,
dataset,
num_workers=1,
num_threads=1,
worker_id=0,
batch_size=1,
metric_names=[],
@@ -477,6 +479,8 @@ def __init__(
self.collate_fn = collate_fn
self.device = device
self.sample_indices = sample_indices
self.num_threads = num_threads
self.worker_id = worker_id

if not dist.is_initialized():
dist.init_distributed()
@@ -494,13 +498,9 @@ def __init__(
if self.worker_id == 0:
logger.info(f"Distributed data analyzer initialized with {self.num_workers} workers.")

def run_map_reduce(self):

# setup individual dataloaders
worker_splits, _ = split_dataset(self.dataset, self.num_workers, self.worker_id, num_threads=1)
start_idx, end_idx = worker_splits[self.worker_id]
logger.info(f"worker {self.worker_id}: start working on data subset {start_idx} to {end_idx}")
worker_dataset = Subset(self.dataset, list(range(start_idx, end_idx)))
def run_map_helper(self, thread_id=0, metric_queues=None):
thread_start_idx, thread_end_idx = self.thread_splits[thread_id][0], self.thread_splits[thread_id][1]
worker_dataset = Subset(self.dataset, list(range(thread_start_idx, thread_end_idx)))
sampler = BatchSampler(SequentialSampler(worker_dataset), batch_size=self.batch_size, drop_last=False)
dataloader = DataLoader(dataset=worker_dataset,
batch_sampler=sampler,
@@ -516,7 +516,7 @@ def run_map_reduce(self):
metric_results.append([] if metric_type == 'single_value_per_sample' else None)

# iterate dataloader and store metric results
batch_start_idx = start_idx
batch_start_idx = thread_start_idx
for data in dataloader:
for m_idx in range(len(self.metric_names)):
metric_type, metric_function = self.metric_types[m_idx], self.metric_functions[m_idx]
@@ -544,15 +544,73 @@ def run_map_reduce(self):
metric_results[m_idx].add_(metric_values)
batch_start_idx += len(data)

if self.num_threads == 1:
return metric_results

# copy metric_results to the shared queue
assert metric_queues
for m_idx in range(len(self.metric_names)):
results = metric_results[m_idx]
if torch.is_tensor(results):
results = results.item() if results.dim() == 0 else results.tolist()
try:
metric_queues[m_idx].put((thread_id, results))
except Exception as e:
logger.error(f"Error putting metric results to queue: {e}")
sys.exit(1)

def run_map_reduce(self):

# setup individual dataloaders
self.worker_splits, self.thread_splits = split_dataset(self.dataset,
self.num_workers,
self.worker_id,
num_threads=self.num_threads)
node_start_idx, node_end_idx = self.worker_splits[self.worker_id]
logger.info(f"worker {self.worker_id} working on data subset {node_start_idx} to {node_end_idx}.")

if self.num_threads in [0, 1, None]:
metric_results = self.run_map_helper()
metric_results = [torch.tensor(m).to(self.device) for m in metric_results]
else:

# create a shared queue of results per metric to be populated by individual threads
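# (a Manager queue proxy is picklable, unlike a plain queue.Queue, so child processes can write to it)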
with Manager() as manager:
metric_queues = [manager.Queue() for _ in self.metric_names]
threads = [
Process(target=self.run_map_helper, args=(t, metric_queues)) for t in range(self.num_threads)
]
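# despite the name 'threads', these are separate processes, so CPU-bound metric functions avoid the GIL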
for thread in threads:
thread.start()
for thread in threads:
thread.join()

# gather results from shared queues into metric_results
metric_results = [None for _ in self.metric_names]
for m_idx, (queue, metric_type) in enumerate(zip(metric_queues, self.metric_types)):
while not queue.empty():
t_idx, t_results = queue.get()
t_start_idx, t_end_idx = self.thread_splits[t_idx]
if t_start_idx >= t_end_idx: # no results from this thread
continue  # corner case for small datasets and high thread count
t_results = torch.tensor(t_results)
if metric_type == 'single_value_per_sample':
# add thread results to the metric_results list, ordered by thread idx
if metric_results[m_idx] is None: # initialize if needed
metric_results[m_idx] = torch.zeros(node_end_idx - node_start_idx,
t_results.size(1)).to(self.device)
metric_results[m_idx][t_start_idx - node_start_idx:t_end_idx - node_start_idx] = t_results
else:
if metric_results[m_idx] is None: # initialize if needed
metric_results[m_idx] = torch.zeros(t_results.size()).to(self.device)
metric_results[m_idx].add_(t_results)

# compute dtype for sample ids
total_num_samples = len(self.dataset)
sample_idx_dtype = find_fit_int_dtype(0, total_num_samples - 1)
logger.info(f"Total number of data samples: {total_num_samples}.")
logger.info(f"Will use {sample_idx_dtype} to store the sample indexes.")

# convert to list of tensors
metric_results = [torch.tensor(m).to(self.device) for m in metric_results]

for m_idx in range(len(self.metric_names)):
metric_values, metric_name, metric_type = \
metric_results[m_idx], self.metric_names[m_idx], self.metric_types[m_idx]
@@ -611,8 +669,8 @@ def run_map_reduce(self):
def file_write_ordered(self, tensor_list, fname, numpy_dtype):
""" MPI_file_write_ordered extended to write a list of tensors, by one rank, iteratively """

# each not has a list of rows (tensors) to be written to the file.
# we will serialize it to communicate it in one comm step.
# each node has a list of rows (tensors) to be written to the file.
# we will serialize it in order to communicate it in one comm step.

tkwargs = dict(dtype=torch.int64, device=self.device)

@@ -636,17 +694,13 @@ def file_write_ordered(self, tensor_list, fname, numpy_dtype):
def write_buffer_to_file(buff, src, builder):
assert self.worker_id == 0, "only rank 0 can write to file"

# # write one buffer at a time
# for row_len in row_lens[src]:
# builder.add_item(buff[:row_len].cpu())
# buff = buff[row_len:]

# collect all buffers and write them all at once
buffer_list = []
for row_len in row_lens[src]:
buffer_list.append(buff[:row_len].cpu())
buff = buff[row_len:]
builder.add_items(buffer_list)
# collect all buffers and write them at once
buff = buff.cpu().detach().numpy()
row_offsets = np.cumsum([0] + row_lens[src].tolist())
arr_list = []
for i in range(len(row_lens[src])):
arr_list.append(buff[row_offsets[i]:row_offsets[i + 1]])
builder.add_items(arr_list)

# 5. rank 0 prepares output folder and file
if self.worker_id == 0:
@@ -700,7 +754,7 @@ def gather_v(tensor, dst, comm_group, num_workers, worker_id):
# all_gather requires all tensors to be of same size so we need to pad them
max_size = max(sizes).item()
buffer = torch.empty(max_size, dtype=tensor.dtype, device=tensor.device)
buffer[0:size] = torch.tensor(tensor, dtype=tensor.dtype, device=tensor.device)
buffer[0:size] = tensor.data
buffer_list = None
if worker_id == 0: # create padded recv buffers
buffer_list = [torch.empty(max_size, dtype=tensor.dtype, device=tensor.device) for _ in range(num_workers)]
@@ -763,16 +817,18 @@ def sample_sort(tensor, comm_group, num_workers, n_samples=100):
def test_compare_both_data_analyzers(dataset):
""" given a dataset, compare file and memory based data analyser"""

id = lambda t: torch.tensor(t).to(torch.int64) # identity
id = lambda t: t.to(torch.int64) # identity
batch_sum = lambda t: id(t).sum()  # sum batch
num_threads = 4
kwargs = dict(
dataset=dataset,
batch_size=3,
batch_size=2**10,
worker_id=int(os.environ['RANK']),
num_workers=int(os.environ['WORLD_SIZE']),
metric_names=["mod", "batch_sum"],
metric_functions=[id, batch_sum],
metric_types=['single_value_per_sample', 'accumulate_value_over_samples'],
num_threads=num_threads,
)

dda = DistributedDataAnalyzer(
@@ -785,10 +841,9 @@ def test_compare_both_data_analyzers(dataset):
if dda.worker_id == 0:
print("DistributedDataAnalyzer runtime: %s seconds " % (time.time() - start_time))

da = DataAnalyzer(num_threads=2,
num_threads_reduce=2,
metric_dtypes=[torch.int64, torch.int64],
da = DataAnalyzer(num_threads_reduce=num_threads,
save_path="./output_disk",
metric_dtypes=[torch.int64, torch.int64],
**kwargs)
start_time = time.time()
da.run_map_reduce()
@@ -815,14 +870,11 @@ def test_compare_both_data_analyzers(dataset):

class TestDataset(torch.utils.data.Dataset):

def __init__(self, size=20):
self.values = [1001 + x % 6 for x in range(size)]
def __init__(self, size=10_000_000):
self.values = [(x + 7) % 10_000 for x in range(size)]
self.size = size

def __len__(self):
return self.size

def __getitem__(self, idx):
return self.values[idx]
__len__ = lambda self: self.size
__getitem__ = lambda self, idx: self.values[idx]

test_compare_both_data_analyzers(TestDataset())
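
For reference, the reduce-side speedup in write_buffer_to_file above boils down
to one device-to-host copy plus zero-copy numpy views, instead of one .cpu()
transfer per row. A standalone sketch of that idea (buffer and row lengths are
made up):

import numpy as np
import torch

buff = torch.arange(10)                       # stand-in for a gathered buffer
row_lens = torch.tensor([3, 3, 4])            # per-row lengths from one source rank

arr = buff.cpu().detach().numpy()             # single host copy, detached from autograd
offsets = np.cumsum([0] + row_lens.tolist())  # row boundaries within the flat buffer
rows = [arr[offsets[i]:offsets[i + 1]] for i in range(len(row_lens))]
print([r.tolist() for r in rows])             # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]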
6 changes: 3 additions & 3 deletions deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py
@@ -586,9 +586,9 @@ def add_item(self, tensor):
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)

def add_items(self, tensor_list):
""" write a list of tensors to the file and update their sizes in the index"""
np_arrays = [np.array(t.numpy(), dtype=self._dtype) for t in tensor_list]
def add_items(self, arr_list):
""" write a list of arrays to the file and update their sizes in the index"""
np_arrays = [arr.astype(self._dtype) for arr in arr_list]
self._data_file.writelines([arr.tobytes(order='C') for arr in np_arrays])
for arr in np_arrays:
self._sizes.append(arr.size)
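
For context, a hedged usage sketch of the new add_items signature; the builder
helpers are the ones imported in data_analyzer.py above, and the output path
and dtype here are illustrative assumptions:

import numpy as np
from deepspeed.runtime.data_pipeline.data_sampling.utils import (
    create_mmap_dataset_builder, close_mmap_dataset_builder)

fname = "./output/metric_values"  # illustrative output path (prefix for .bin/.idx)
builder = create_mmap_dataset_builder(fname, np.int64)
rows = [np.array([1, 2, 3]), np.array([4, 5])]  # numpy arrays, not tensors, after this change
builder.add_items(rows)  # one writelines() call; sizes recorded per row
close_mmap_dataset_builder(builder, fname)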
