Save and resume the state of a DataLoader #5454
Comments
Something that'd be nice to have is "manual update of state". One of the learnings from training LLMs is that the ability to skip some batches whenever we notice a huge spike might be handy.
Your outline spec is very sound and clear, @lhoestq - thank you! @thomasw21, indeed that would be a wonderful extra feature. In Megatron-Deepspeed we manually drained the dataloader for the range we wanted. I wasn't very satisfied with the way we did it, since its behavior would change if you were to do multiple range skips. I think it should remember all the ranges it skipped and not just skip the last range - since otherwise the data is inconsistent (but we probably should discuss this in a separate issue so as not to derail this much bigger one).
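(For illustration only, not from this thread: a minimal sketch of the "remember all skipped ranges" idea, where the class and method names are hypothetical. The wrapper keeps every skipped `[start, end)` range of batch indices, so replaying with the same skip list reproduces the same effective data stream.)

```python
from typing import Iterable, Iterator, List, Tuple

class SkipRangesIterator:
    """Wraps a batch iterator and skips every recorded [start, end) range of batch indices."""

    def __init__(self, batches: Iterable, skipped_ranges: List[Tuple[int, int]]):
        self.batches = batches
        # All ranges ever skipped are remembered, not just the last one,
        # so the effective data stream stays consistent across resumes.
        self.skipped_ranges = list(skipped_ranges)

    def skip_range(self, start: int, end: int) -> None:
        # Record a new range to skip (e.g. around a loss spike).
        self.skipped_ranges.append((start, end))

    def __iter__(self) -> Iterator:
        for idx, batch in enumerate(self.batches):
            if any(start <= idx < end for start, end in self.skipped_ranges):
                continue  # this batch falls inside a previously skipped range
            yield idx, batch
```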
Hi there! I think this is a critical issue and I have an urgent need for it in my attempt to train on a super large-scale dataset. @stas00 @thomasw21 @lhoestq Any updates on this problem after a year has passed?
Any update?
No update so far. I wonder if someone has implemented a resumable PyTorch Sampler somewhere. Regarding resuming a streaming dataset, we'd first like to have an efficient way to skip shards automatically, but this is not implemented yet.
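(Editorial note, purely illustrative: a rough sketch of what such a resumable PyTorch Sampler could look like. The `state_dict`/`load_state_dict` names follow the convention used later in this thread; with `num_workers > 0` the dataloader's prefetching would make the `consumed` counter run ahead of the training loop, so a real implementation needs more care.)

```python
import torch
from torch.utils.data import Sampler

class ResumableRandomSampler(Sampler):
    """Random sampler whose position within the current epoch can be saved and restored."""

    def __init__(self, num_samples: int, seed: int = 0):
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0
        self.consumed = 0  # indices already yielded in the current epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # same permutation for a given epoch
        permutation = torch.randperm(self.num_samples, generator=g).tolist()
        for idx in permutation[self.consumed:]:
            self.consumed += 1
            yield idx
        # epoch finished: start the next one from scratch
        self.consumed = 0
        self.epoch += 1

    def __len__(self):
        return self.num_samples

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "consumed": self.consumed}

    def load_state_dict(self, state):
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.consumed = state["consumed"]
```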
I opened a draft here for IterableDataset: #6658

```python
"""Requires https://github.com/huggingface/datasets/pull/6658 (WIP)"""
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(..., streaming=True)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42, buffer_size=1000)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = ds.state_dict()

# Resumable training loop
ds.load_state_dict(dataset_state_dict)
dataloader = DataLoader(ds, batch_size=batch_size)
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        dataset_state_dict = ds.state_dict()
```
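(Illustrative follow-up, under the same assumption that the draft PR above provides `state_dict`/`load_state_dict` on the dataset: since the dataset state is a plain picklable dict, it can be checkpointed next to the model and optimizer with `torch.save`. The function names below are hypothetical.)

```python
import torch

def save_checkpoint(path, model, optimizer, dataset_state_dict, step):
    # The dataset state is just a dict, so it can live in the same checkpoint file.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dataset": dataset_state_dict,
            "step": step,
        },
        path,
    )

def load_checkpoint(path, model, optimizer, ds):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    ds.load_state_dict(checkpoint["dataset"])  # requires the draft PR above
    return checkpoint["step"]
```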
Hi @lhoestq - can you provide more information on how to save and restore vanilla DataLoader states with map-style datasets?
For now the easiest is probably to use the vanilla DataLoader only for batching and multiprocessing, and implement the resuming logic yourself, e.g. with a small state dict and `ds.select()` to skip the already-seen samples:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(...)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = {"step": 0}

# Resumable training loop
start_step = dataset_state_dict["step"]
dataloader = DataLoader(ds.select(range(start_step * batch_size, len(ds))), batch_size=batch_size)
for step, batch in enumerate(dataloader, start=start_step):
    ...
    if step % save_steps == 0:
        dataset_state_dict = {"step": step}
```
Hello, I found a similar implementation online that seems to solve your problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py#L93
Hi y'all, @lhoestq - I wanted to flag that we currently have a StatefulDataLoader in `torchdata`.
Fantastic! This will help push our IterableDataset state_dict implementation at #6658 :) I'll check if there is anything missing to make them work together, and add tests and some docs referring to the StatefulDataLoader :)
Ah, I just saw this disclaimer in the torchdata README and it feels like people should not rely on it. Should the StatefulDataLoader live elsewhere, @andrewkho?
@lhoestq Good find, we are in the midst of updating this disclaimer as we're re-starting development and regular releases, though our approach will be to iterate on DL V1 (i.e. StatefulDataLoader) instead of continuing development on datapipes + DL V2. Let's discuss on a call at some point to figure out the best path forward!
As a heads up, the IterableDataset state_dict/load_state_dict from #6658 is now available, and it works out of the box with the `StatefulDataLoader` from `torchdata`. See the docs at https://huggingface.co/docs/datasets/main/en/use_with_pytorch#checkpoint-and-resume
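(Rough sketch of the pattern the linked docs describe; the batch size and worker count are placeholders. The `StatefulDataLoader` exposes `state_dict`/`load_state_dict`, and with this integration the position in the streaming dataset should be captured as part of that state.)

```python
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

ds = load_dataset(..., streaming=True)  # an IterableDataset
dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)

# Train for a while, then checkpoint the dataloader state
# (this should also capture the position in the streaming dataset).
state_dict = dataloader.state_dict()

# Later, to resume: recreate the dataloader and restore its state before iterating.
dataloader = StatefulDataLoader(ds, batch_size=32, num_workers=4)
dataloader.load_state_dict(state_dict)
for batch in dataloader:
    ...
```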
Amazing! Thank you, @lhoestq. Does it work with non-iterable datasets as well? The docs only mention iterable datasets.
It's for iterable datasets only. For a regular dataset I believe the sampler should implement state_dict, but @andrewkho might know best how to resume a regular dataset with torchdata.
@stas00 The StatefulDataLoader will save and resume samplers for map-style datasets. If no state_dict/load_state_dict is provided by the sampler, it will naively skip samples to fast-forward. See here for more details: https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/README.md Hope this helps!
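(Illustrative sketch of the map-style case described above: the dataloader state is checkpointed via `state_dict`/`load_state_dict`, and samplers without their own `state_dict` are resumed by naively skipping already-seen samples, as noted in the linked README. Batch size and save interval are placeholders.)

```python
from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

ds = load_dataset(...)  # a regular map-style dataset
dataloader = StatefulDataLoader(ds, batch_size=32, shuffle=True)

save_steps = 100
dataloader_state = dataloader.state_dict()
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        # If the sampler does not define state_dict/load_state_dict,
        # resuming fast-forwards by skipping already-seen samples.
        dataloader_state = dataloader.state_dict()

# Resume: restore the state on a fresh dataloader before iterating again.
dataloader = StatefulDataLoader(ds, batch_size=32, shuffle=True)
dataloader.load_state_dict(dataloader_state)
```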
Thank you very much for clarifying that, Andrew.
👋 I am trying to use the StatefulDataLoader with a streaming dataset on multiple nodes. So far, I have been able to use an HF streaming dataset + torch DDP with the vanilla DataLoader: I implemented a custom iterable to make sure that shards are distributed across the multiple nodes, while letting the DataLoader handle batching and multiprocessing.

This doesn't seem to work with the StatefulDataLoader, though. I confirmed that if I skip the custom class and pass the streaming dataset directly, saving and resuming works. Is there something I am missing? It would be nice if streaming datasets worked with the StatefulDataLoader in this multi-node setup.
Hi! Have you tried using `split_dataset_by_node` directly with the `StatefulDataLoader`?

```python
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(split_dataset_by_node(dataset, rank=process_rank, world_size=world_size))
```
@lhoestq it took me some time to test, but it works like a charm. Thanks for the pointer. Totally missed this 🤦.
It would be nice when using `datasets` with a PyTorch DataLoader to be able to resume a training from a DataLoader state (e.g. to resume a training that crashed).

What I have in mind (but lmk if you have other ideas or comments):
- For map-style datasets, this requires having a PyTorch Sampler state that can be saved and reloaded per node and worker.
- For iterable datasets, this requires saving the state of the dataset iterator, which includes e.g. the current shard and the position within it.
Right now you can already resume the data loading of an iterable dataset by using `IterableDataset.skip`, but it takes a lot of time because it re-iterates on all the past data until it reaches the resuming point.

cc @stas00 @sgugger
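(Illustration of the skip-based workaround mentioned above; the number of seen examples is a placeholder. It is slow because the skipped examples are still streamed and discarded one by one.)

```python
from datasets import load_dataset

ds = load_dataset(..., streaming=True)

num_examples_seen = 10_000  # e.g. recorded before the crash
resumed_ds = ds.skip(num_examples_seen)  # re-iterates over the skipped examples under the hood
```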