Save and resume the state of a DataLoader #5454

Open
lhoestq opened this issue Jan 23, 2023 · 21 comments
Labels: enhancement (New feature or request), generic discussion (Generic discussion on the library)

Comments

@lhoestq
Member

lhoestq commented Jan 23, 2023

It would be nice, when using datasets with a PyTorch DataLoader, to be able to resume training from a DataLoader state (e.g. to resume a training run that crashed).

What I have in mind (but lmk if you have other ideas or comments):

For map-style datasets, this requires a PyTorch Sampler whose state can be saved and reloaded per node and per worker.

For iterable datasets, this requires saving the state of the dataset iterator (a rough sketch is shown right after this list), which includes:

  • the current shard idx and row position in the current shard
  • the epoch number
  • the rng state
  • the shuffle buffer
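
For illustration, here is what such an iterator state could roughly contain (field names are purely illustrative, not an actual datasets API):

import numpy as np

rng = np.random.default_rng(seed=42)
iterator_state = {
    "shard_idx": 3,                        # index of the shard currently being read
    "shard_example_idx": 1024,             # row position inside that shard
    "epoch": 2,                            # epoch number (drives reshuffling across epochs)
    "rng_state": rng.bit_generator.state,  # rng state used for shuffling shards/examples
    "shuffle_buffer": ["ex_0", "ex_1"],    # examples currently held in the shuffle buffer
}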

Right now you can already resume the data loading of an iterable dataset by using IterableDataset.skip, but it takes a lot of time because it re-iterates over all the past data until it reaches the resuming point.
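
For example, a minimal sketch of that skip-based workaround (the dataset name and the number of seen examples are placeholders):

from datasets import load_dataset

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
examples_already_seen = 100_000  # tracked by the training loop before the crash

# skip() returns a new IterableDataset, but it still iterates over the skipped
# examples under the hood, which is why this gets slow for large offsets
resumed_ds = ds.skip(examples_already_seen)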

cc @stas00 @sgugger

lhoestq added the enhancement and generic discussion labels on Jan 23, 2023
@thomasw21
Contributor

Something that'd be nice to have is a "manual update of state". One lesson from training LLMs is that the ability to skip some batches whenever we notice a huge spike would be handy.

@stas00
Contributor

stas00 commented Jan 24, 2023

Your outline spec is very sound and clear, @lhoestq - thank you!

@thomasw21, indeed that would be a wonderful extra feature. In Megatron-DeepSpeed we manually drained the dataloader for the range we wanted. I wasn't very satisfied with the way we did it, since its behavior would change if you were to do multiple range skips. I think it should remember all the ranges it skipped and not just skip the last range, since otherwise the data is inconsistent (but we should probably discuss this in a separate issue so as not to derail this much bigger one).
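
To make the "remember all the ranges" point concrete, a purely illustrative sketch (not an existing API):

skipped_ranges = [(1200, 1300), (5400, 5600)]  # global step ranges that were manually skipped

def should_skip(step: int) -> bool:
    # A batch is dropped if its global step falls in any previously skipped range,
    # so the data stays consistent across multiple resumptions.
    return any(start <= step < end for start, end in skipped_ranges)

steps_actually_run = [step for step in range(10_000) if not should_skip(step)]
print(len(steps_actually_run))  # 10_000 - 100 - 200 = 9_700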

@yqy2001

yqy2001 commented Jan 25, 2024

Hi there! I think this is a critical issue, and I have an urgent need for it in my attempt to train on a super large-scale dataset using datasets. It is impossible to resume a time-consuming (like one month) experiment by iterating over all the seen data again, which could cost several days.

@stas00 @thomasw21 @lhoestq Any updates on this problem now that a year has passed?

@dancingpipi

any update?

@lhoestq
Member Author

lhoestq commented Feb 2, 2024

No update so far. I wonder if someone has implemented a resumable PyTorch Sampler somewhere.

Then, regarding resuming a streaming dataset, we'd first like to have an efficient way to skip shards automatically, but this is not implemented yet.
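
For what it's worth, the shard-skipping idea boils down to something like this (a sketch only; the per-shard example counts would have to be known or stored in metadata):

examples_per_shard = [5000, 5000, 4200, 5000]  # assumed to be known ahead of time
target = 12_345                                # global example index to resume at

shard_idx, remaining = 0, target
while remaining >= examples_per_shard[shard_idx]:
    remaining -= examples_per_shard[shard_idx]  # skip the whole shard without reading it
    shard_idx += 1

print(shard_idx, remaining)  # resume at shard 2, row 2345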

@lhoestq
Member Author

lhoestq commented Feb 19, 2024

I opened a draft here for IterableDataset: #6658

"""Requires https://github.com/huggingface/datasets/pull/6658 (WIP)"""
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(..., streaming=True)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42, buffer_size=1000)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = ds.state_dict()

# Resumable training loop
ds.load_state_dict(dataset_state_dict)
dataloader = DataLoader(ds, batch_size=batch_size)
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        dataset_state_dict = ds.state_dict()

@jwliu36

jwliu36 commented Feb 21, 2024

Hi @lhoestq - can you provide more information on how to implement saving and restoring vanilla DataLoader states with map-style datasets?

@lhoestq
Member Author

lhoestq commented Feb 21, 2024

For now the easiest approach is probably to use the vanilla DataLoader only for batching and multiprocessing, and implement the resuming logic using a Dataset (it has .select() to skip examples) and a dataset_state_dict:

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(...)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = {"step": 0}  

# Resumable training loop
start_step = dataset_state_dict["step"]
dataloader = DataLoader(ds.select(range(start_step * batch_size, len(ds))), batch_size=batch_size)
for step, batch in enumerate(dataloader, start=start_step):
    ...
    if step % save_steps == 0:
        dataset_state_dict = {"step": step}
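
Continuing the sketch above, the dataset_state_dict can be persisted next to the model checkpoint (the file name is just an example):

import json

# save alongside the model checkpoint
with open("dataset_state.json", "w") as f:
    json.dump(dataset_state_dict, f)

# on restart
with open("dataset_state.json") as f:
    dataset_state_dict = json.load(f)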

@xgbj

xgbj commented Mar 19, 2024

Hello, I found a similar implementation online that seems to solve your problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py#L93
It looks like we can use set_start_iter in StatefulDistributedSampler to implement the stateful resume behavior we want.
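
For reference, the core idea behind that sampler looks roughly like this (a simplified sketch, not the actual vissl code):

from torch.utils.data.distributed import DistributedSampler

class ResumableDistributedSampler(DistributedSampler):
    """Drops the indices that were already consumed before a checkpoint."""

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size
        self.start_iter = 0

    def set_start_iter(self, start_iter):
        # number of batches this rank already consumed before resuming
        self.start_iter = start_iter

    def __iter__(self):
        indices = list(super().__iter__())
        # skip the indices that were already consumed before the checkpoint
        return iter(indices[self.start_iter * self.batch_size:])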

@andrewkho

Hi y'all, @lhoestq I wanted to flag that we currently have a StatefulDataLoader in pytorch/data/torchdata that has state_dict/load_state_dict methods, which will call a dataset's state_dict/load_state_dict methods but also handle multiprocessing under the hood. Any chance we can collaborate on this and try to get them to work well together? Please have a look here for some basic examples: https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader#saving-and-loading-state
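
For anyone landing here, a minimal sketch of that pattern (the toy list dataset is just for illustration; the linked README has the real examples):

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(1_000))  # any map-style dataset works here
dataloader = StatefulDataLoader(dataset, batch_size=4, num_workers=2)

state = None
for step, batch in enumerate(dataloader):
    if step == 10:
        state = dataloader.state_dict()  # captures sampler, worker and dataset state
        break

# a fresh dataloader resumes mid-epoch from the saved state
dataloader = StatefulDataLoader(dataset, batch_size=4, num_workers=2)
dataloader.load_state_dict(state)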

@lhoestq
Member Author

lhoestq commented Apr 30, 2024

Fantastic! This will help push our IterableDataset state_dict implementation at #6658 :) I'll check if there is anything missing to make them work together, and add tests and some docs referring to the StatefulDataLoader :)

@lhoestq
Member Author

lhoestq commented Apr 30, 2024

Ah, I just saw this disclaimer in the torchdata README, and it feels like people should not rely on it. Should the StatefulDataLoader live elsewhere, @andrewkho?

⚠️ As of July 2023, we have paused active development on TorchData and have paused new releases. We have learnt a lot from building it and hearing from users, but also believe we need to re-evaluate the technical design and approach given how much the industry has changed since we began the project. During the rest of 2023 we will be re-evaluating our plans in this space. Please reach out if you have suggestions or comments (please use pytorch/data#1196 for feedback).

@andrewkho

@lhoestq Good find - we are in the midst of updating this disclaimer as we're re-starting development and regular releases, though our approach will be to iterate on DL V1 (i.e. StatefulDataLoader) instead of continuing development on datapipes + DL V2. Let's discuss on a call at some point to figure out the best path forward!

@lhoestq
Member Author

lhoestq commented Jul 22, 2024

As a heads up, IterableDataset state_dict has been added in #6658

...and it works out of the box with the torchdata StatefulDataLoader :)

See the docs at https://huggingface.co/docs/datasets/main/en/use_with_pytorch#checkpoint-and-resume
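
Condensed, the pattern in those docs looks roughly like this (the dataset name is a placeholder; the linked page is the authoritative version):

from datasets import load_dataset
from torchdata.stateful_dataloader import StatefulDataLoader

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
dataloader = StatefulDataLoader(ds, batch_size=8)

state_dict = None
for step, batch in enumerate(dataloader):
    if step == 10:
        state_dict = dataloader.state_dict()  # also captures the IterableDataset state
        break

# resuming continues mid-stream instead of restarting from the first example
dataloader = StatefulDataLoader(ds, batch_size=8)
dataloader.load_state_dict(state_dict)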

@stas00
Contributor

stas00 commented Jul 24, 2024

Amazing! Thank you, @lhoestq

Does it work with non-iterable datasets as well? The docs only mention iterable datasets.

@lhoestq
Member Author

lhoestq commented Jul 25, 2024

It's for iterable datasets only. For regular datasets I believe the sampler should implement state_dict, but @andrewkho might know best how to resume a regular dataset with torchdata.

@andrewkho

andrewkho commented Jul 25, 2024

@stas00 StatefulDataLoader will save and resume samplers for map-style datasets. If the sampler does not provide state_dict/load_state_dict, it will naively skip samples to fast-forward. See here for more details: https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/README.md

Hope this helps!

@stas00
Contributor

stas00 commented Jul 25, 2024

Thank you very much for clarifying that, Andrew.

@diego-andres-ardila

👋 I am trying to use HF streaming datasets + PyTorch DDP + StatefulDataLoader to train using multiple nodes and large datasets.

So far, I have been able to use HF streaming datasets + PyTorch DDP with a vanilla DataLoader. To do so, I implemented a custom iterable to make sure that shards are distributed across the multiple nodes, while letting the dataset take care of the multiple workers. The implementation uses split_dataset_by_node:

import torch
from torch.distributed import get_rank, get_world_size
from torch.utils.data import DataLoader, IterableDataset

from datasets.distributed import split_dataset_by_node

class MyIterableDataset(IterableDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self._iterable_by_node = None
        self._state_dict = {}  # holds the latest state of the wrapped dataset

    def __iter__(self):
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            world_size = get_world_size()
            process_rank = get_rank()
        else:
            world_size = 1
            process_rank = 0

        if world_size > 1:
            self._iterable_by_node = split_dataset_by_node(
                self.dataset, rank=process_rank, world_size=world_size
            )
        else:
            self._iterable_by_node = self.dataset

        for example in self._iterable_by_node:
            # Trying with _state_dict, since `.state_dict()` creates a copy
            self._state_dict.update(self._iterable_by_node._state_dict)
            yield example

    def state_dict(self):
        return self._state_dict

    def load_state_dict(self, state):
        pass  # Not implemented yet

This doesn't seem to work with StatefulDataLoader, though. I can see the state of the worker's dataset being updated in its corresponding worker processes, but somehow the updates are not propagated back to the main process. I have tried different variants of the above code without success.

I confirmed that if I skip the custom class and pass the dataset directly to the loader as in the docs, the StatefulDataLoader sees the updates for each worker. However, if I do this, multiple nodes will see the same examples, which I definitely don't want.

Is there something I am missing? It would be nice if streaming datasets supported multi-node training by default (unless they already do and I am missing something).

@lhoestq
Member Author

lhoestq commented Nov 25, 2024

Hi ! Have you tried using split_dataset_by_node() and passing the result to the StatefulDataLoader?

dataloader = StatefulDataLoader(split_dataset_by_node(dataset, rank=process_rank, world_size=world_size))
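
Fleshed out a bit (the dataset name and env-var handling are placeholders), that could look like:

import os

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

# rank / world size typically come from the launcher (e.g. torchrun) via env vars
process_rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
ds = split_dataset_by_node(ds, rank=process_rank, world_size=world_size)

dataloader = StatefulDataLoader(ds, batch_size=8)
state = dataloader.state_dict()    # each rank checkpoints its own state
dataloader.load_state_dict(state)  # ...and reloads it when resuming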

@diego-andres-ardila

> Hi ! Have you tried using split_dataset_by_node() and passing the result to the StatefulDataLoader?

@lhoestq it took me some time to test, but it works like a charm. Thanks for the pointer. Totally missed this 🤦.
