[QEff. Finetune] : Sort dataset based on length for DDP (#211)
Sort dataset samples based on length for DDP.
Fix for run_validation = False.

---------

Signed-off-by: Mamta Singh <[email protected]>
quic-mamta authored and quic-rishinr committed Jan 10, 2025
1 parent 2fed663 commit de97e39
Showing 5 changed files with 110 additions and 32 deletions.
9 changes: 6 additions & 3 deletions QEfficient/cloud/finetune.py
@@ -173,9 +173,12 @@ def main(**kwargs):
         else:
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")

-    longest_seq_length, longest_seq_ix = get_longest_seq_length(
-        torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
-    )
+        longest_seq_length, _ = get_longest_seq_length(
+            torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
+        )
+    else:
+        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)

     print(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
1 change: 1 addition & 0 deletions QEfficient/finetune/configs/training.py
@@ -38,6 +38,7 @@ class train_config:
     save_metrics: bool = True  # saves training metrics to a json file for later plotting
     intermediate_step_save: int = 1000
     batching_strategy: str = "packing"
+    enable_sorting_for_ddp: bool = True

     # TODO: vbaddi: Uncomment post adding qaic to Pytorch Profiler
     # flop_counter: bool = False  # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time.
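
Since `train_config` is a plain dataclass, the new flag can also be toggled programmatically; a small sketch (assuming all fields carry defaults, as above):

```python
# Sketch: flip the new flag in code rather than through CLI overrides.
from QEfficient.finetune.configs.training import train_config

config = train_config()
config.enable_sorting_for_ddp = False  # fall back to a plain DistributedSampler
print(config.batching_strategy, config.enable_sorting_for_ddp)
```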
62 changes: 62 additions & 0 deletions QEfficient/finetune/data/sampler.py
@@ -0,0 +1,62 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import random
from itertools import islice

import numpy as np
import torch


class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
        # Compute per-sample lengths; for dict-style samples, use the first key.
        if isinstance(next(iter(data_source)), dict):
            first_key = next(iter(next(iter(data_source)).keys()))
            self.lengths = [len(d[first_key]) for d in data_source]
        else:
            self.lengths = [len(d) for d in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        # Stable sort keeps the original order among equal-length samples.
        ids = np.argsort(self.lengths, kind="mergesort")
        if self.drop_last:
            ids = ids[: len(ids) // self.batch_size * self.batch_size]

        batches = [ids[i : i + self.batch_size] for i in range(0, len(ids), self.batch_size)]

        # Shuffle whole batches, so each batch still holds similar-length samples.
        if self.shuffle:
            random.shuffle(batches)

        for b in batches:
            yield b

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)


class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    def __init__(
        self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0
    ) -> None:
        random.seed(seed)
        # Reuse the single-process sampler and shard its batches across ranks.
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # Round-robin: rank r takes batches r, r + num_replicas, r + 2 * num_replicas, ...
        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)

    def __len__(self):
        return len(self.batch_sampler) // self.num_replicas
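
A quick illustration of how the two samplers interact (the toy dataset, world size of 2, and ranks below are illustrative, not part of the commit): each rank sorts by length, forms fixed-size batches, then takes every `num_replicas`-th batch starting at its own rank.

```python
# Illustrative only: shard length-sorted batches across two hypothetical ranks.
import torch

from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler

data = [torch.arange(n) for n in (5, 3, 9, 2, 7, 4, 8, 6)]  # variable-length samples

for rank in range(2):
    sampler = DistributedLengthBasedBatchSampler(
        data, batch_size=2, num_replicas=2, rank=rank, shuffle=False
    )
    print(rank, [b.tolist() for b in sampler])
# rank 0 sees index batches [3, 1] and [7, 4]; rank 1 sees [5, 0] and [6, 2]
```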
56 changes: 28 additions & 28 deletions QEfficient/finetune/utils/config_utils.py
@@ -16,10 +16,12 @@
     PrefixTuningConfig,
 )
 from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq

 import QEfficient.finetune.configs.dataset_config as datasets
 from QEfficient.finetune.configs.peft_config import lora_config, prefix_config
 from QEfficient.finetune.configs.training import train_config
+from QEfficient.finetune.data.sampler import DistributedLengthBasedBatchSampler
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC

@@ -63,41 +65,39 @@ def generate_peft_config(train_config, kwargs):

 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())

     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"

     dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()

     update_config(dataset_config, **kwargs)

     return dataset_config


-# def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
-#     kwargs = {}
-#     batch_size = (
-#         train_config.batch_size_training
-#         if mode == "train"
-#         else train_config.val_batch_size
-#     )
-#     if train_config.batching_strategy == "padding":
-#         kwargs["batch_sampler"] = LengthBasedBatchSampler(
-#             dataset, batch_size, drop_last=True, shuffle=mode == "train"
-#         )
-#         kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
-#     # kwargs["collate_fn"] = default_data_collator
-#     return kwargs
-
-
-def get_dataloader_kwargs(train_config: train_config, dataset, dataset_processer, mode):
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
     kwargs = {}
     batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size
-    kwargs["batch_size"] = batch_size
-    kwargs["drop_last"] = True
-    kwargs["collate_fn"] = default_data_collator
-    # use a distributed sampler to split data between devices
     if train_config.enable_ddp:
-        kwargs["sampler"] = data_utils.DistributedSampler(
-            dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
-        )
+        if train_config.enable_sorting_for_ddp:
+            if train_config.context_length:
+                raise ValueError(
+                    "Sorting cannot be done with padding. Please disable sorting or pass context_length as None to disable padding."
+                )
+            else:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=False,
+                )
+                kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+        else:
+            kwargs["sampler"] = data_utils.DistributedSampler(
+                dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+    else:
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
     return kwargs
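
Downstream, the returned kwargs are unpacked straight into a `torch.utils.data.DataLoader`. A sketch of the consuming side (assuming a DDP process group is already initialized; `train_dataset` and `tokenizer` are placeholders, not names from this diff):

```python
# Sketch: the kwargs carry either a batch_sampler (sorting path) or a
# sampler plus batch_size/drop_last (plain DDP path), and a collate_fn.
import torch.utils.data

dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, mode="train")
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    num_workers=1,
    pin_memory=True,
    **dl_kwargs,
)
```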
14 changes: 13 additions & 1 deletion docs/source/finetune.md
@@ -6,6 +6,9 @@ Same CLI can be used to run Finetuning on gpu by setting the device flag.(for fi
## Installation

Same as QEfficient along with QAIC PyTorch Eager mode.

For the QEfficient library: https://github.com/quic/efficient-transformers

For torch_qaic, assuming QEfficient is already installed,
```bash
pip install /opt/qti-aic/integrations/torch_qaic/py310/torch_qaic-0.1.0-cp310-cp310-linux_x86_64.whl
@@ -51,4 +54,13 @@ python -m QEfficient.cloud.finetune --device qaic:0 --use-peft --output_dir ./me
```bash
QAIC_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc-per-node 4 -m QEfficient.cloud.finetune --device qaic --enable_ddp --dist_backend qccl --num_epochs 2 --model_name "meta-llama/Llama-3.2-1B"
```
-**nproc-per-node is number of workers(gpus) running locally.
+**nproc-per-node is the number of workers (QAIC devices) running locally.

## Visualization

TensorBoard logs are generated inside the runs/ directory with a date and time stamp. To visualize the data:

```bash
tensorboard --logdir runs/<file> --bind_all
```
