
Comparison of Deepspeed Stage 1,2 and 3 vs DDP #4815

Closed
jpatel-bdai opened this issue Dec 14, 2023 · 10 comments
Labels: bug (Something isn't working), training

Comments

@jpatel-bdai

Describe the bug
When the model fits on a single GPU, how does Deepspeed ZeRO stage 1 compare with DDP? In my experiments, the overall training loss initially progresses similarly in both cases, but after a few iterations the Deepspeed ZeRO stage 1 and stage 2 performance degrades.

Expected behavior
I would expect both DDP and Deepspeed ZeRO stage 1 to give similar results when run on a single GPU. The total loss is a combination of a few losses, one of which is the trans loss. Do you have experiments comparing DDP with Deepspeed ZeRO stage 1 or 2 that I can refer to? Are these supposed to give similar performance? The attached screenshots show total loss and trans loss for the single-GPU and 2-GPU experiments.

Screenshots
[Screenshots: total loss and trans loss curves for the single-GPU and 2-GPU experiments]

System info (please complete the following information):

  • OS: [e.g. Ubuntu 20.04]
  • GPUs: A100s -> single GPU and 2 GPUs
  • Python version - 3.10

Docker context
No

jpatel-bdai added the bug and training labels Dec 14, 2023
@tjruwase
Contributor

@jpatel-bdai, all zero stages are expected to match ddp on single gpu runs. So, it appears that you are hitting bugs in zero.

Are you able to share detailed steps to help us repro? Thanks!

@jpatel-bdai
Author

jpatel-bdai commented Dec 14, 2023

I will try to share the detailed steps to reproduce if possible. I am using pytorch-lightning's Deepspeed strategy. However, are all zero stages expected to match ddp on multi-gpu runs as well? What are the ways to debug the comparison if I am unable to share the code?

@tjruwase
Contributor

Ideally, we expect zero stages to match ddp in multi-gpu runs, since zero is designed to be a memory-efficient ddp algorithm. In terms of debugging, a first step would be to inspect the training loss of each forward pass to detect deviations.
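
A minimal sketch of that first step, assuming a Lightning callback that records each training step's loss to a file so a DDP run and a ZeRO run of the same seeded script can be diffed step by step (the callback and file names are illustrative):

import torch
from lightning.pytorch.callbacks import Callback

class StepLossDump(Callback):
    def __init__(self, path):
        self.path = path
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` holds whatever training_step returned (a loss tensor or a dict)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs
        self.losses.append(loss.detach().float().cpu())

    def on_train_end(self, trainer, pl_module):
        torch.save(torch.stack(self.losses), self.path)

# Usage: Trainer(callbacks=[StepLossDump("ddp_losses.pt")]) in one run and
# Trainer(callbacks=[StepLossDump("zero1_losses.pt")]) in the other, then compare the two tensors.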

@chiragjn
Contributor

chiragjn commented Dec 21, 2023

Don't want to hijack this issue, but I noticed that my train loss values are wildly different between stage 2 and stage 3. Is that expected? I understand that minor differences can happen because of different optimizer implementations, but the differences in my case are too severe. I checked that everything was seeded the same way; with multiple restarts of stage 2 and stage 3, results were not exact but consistent within the same stage, just not across stages.

blue, green = stage 2
red = stage 3
[train loss screenshot]

@jpatel-bdai
Author

@tjruwase I have an issue registered here Lightning-AI/pytorch-lightning#19246, but it looks like the issue comes from Deepspeed.
I verified that the modules are initialized with the same weights and set deterministic=True in Trainer(), but the DDP and Deepspeed loss values still do not match on a single GPU. The issue I am currently facing is as follows:
I have
tensor_x = torch.nn.Parameter(torch.zeros((dim_a, dim_b)))
in my model initialization and the following in the forward pass:
tensor_x.data = torch.nn.functional.normalize(tensor_x.data, dim=-1)
During the forward pass, a few of the tensor_x.data values differ at the third decimal place (e.g. 12.04345 vs 12.04556), while the majority of the values are exactly the same. But this impacts the performance of the model: as training progresses, the losses with Deepspeed do not go as low as with DDP. This is with deepspeed_stage_1. Do you have any potential directions I could look into?
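
A minimal sketch of one way to pin down where tensor_x starts to drift, assuming a CPU copy of the parameter is saved after every optimizer step in both runs (the helper and variable names are illustrative, not from this issue):

import torch

def first_divergence(snapshots_ddp, snapshots_zero, atol=1e-6):
    # snapshots_*: lists of CPU tensors, one per optimizer step, saved from each run
    for step, (a, b) in enumerate(zip(snapshots_ddp, snapshots_zero)):
        if not torch.allclose(a, b, atol=atol):
            print(f"step {step}: max abs diff {(a - b).abs().max().item():.3e}")
            return step
    return None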

Here is a sample script where I tried to compare DDP and Deepspeed with a simple MNIST example on a single GPU. During the backward pass, the model weights are updated differently by the Deepspeed ZeRO stage 1/2 optimizer (https://deepspeed.readthedocs.io/en/stable/_modules/deepspeed/runtime/zero/stage_1_and_2.html) than by Adam under DDP.

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
tmpdir = os.getcwd()
from lightning import Trainer, LightningModule, LightningDataModule
from lightning.pytorch.loggers.wandb import WandbLogger

PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
BATCH_SIZE = 256
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())

from lightning.pytorch import Trainer, seed_everything
seed_everything(42, workers=True)

class LitMNIST(LightningModule):

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


class MyDataModule(LightningDataModule):
    def __init__(self, data_dir=PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


if __name__ == '__main__':
    import time
    timestr = time.strftime("%Y%m%d-%H%M%S")
    strategy_name = "deepspeed_stage_1"
    # wandb_logger = WandbLogger(project="test_model",id=f"test_32_{strategy_name}_{timestr}", log_model="all")
    
    model = LitMNIST()
    datamodule = MyDataModule()
    trainer = Trainer(
        devices=1,
        accelerator="cuda",
        max_epochs=10,
        precision=32,
        strategy=strategy_name,
        # logger=wandb_logger,
    )
    trainer.fit(model, datamodule)

#- Lightning Component (e.g. Trainer, LightningModule): Trainer, LightningModule
#- PyTorch Lightning Version : 2.1.0
#- PyTorch Version: 2.1.0+cu121
#- Python version : Python 3.10.12
#- OS (e.g., Linux): Debian
#- CUDA/cuDNN version: NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2
#- GPU models and configuration: NVIDIA L4 (24GB GPU VRAM)
#- How you installed Lightning(conda, pip, source): pip install lightning

@olegsinavski

olegsinavski commented Jan 27, 2024

Hello, I'm debugging the same issue. Since I'm working on VLMs, I found that the inclusion of the vision part (e.g. a timm model) leads to drastically slower convergence, but if I train regular LLMs, both converge well.
Here is a VLM training: [loss curve screenshot]
Here is an LLM training: [loss curve screenshot]

I have already ruled out many options: I fixed BatchNorms, used bf16-true, added "amp" decorators, used the same optimizers, and disabled schedulers, but I still see differences in the "vision" case. There are a few other differences I'm trying to work out...

It would be great to have an actual bit-to-bit test between ddp and deepspeed!
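
As a rough sketch of that kind of check, assuming both runs end up as plain fp32 state_dicts on disk (a ZeRO checkpoint may first need to be consolidated, e.g. with DeepSpeed's zero_to_fp32 utility), one could diff them parameter by parameter; the paths are placeholders:

import torch

def compare_state_dicts(path_a, path_b):
    # Load both checkpoints on CPU and report any parameter that is not bit-identical
    sd_a = torch.load(path_a, map_location="cpu")
    sd_b = torch.load(path_b, map_location="cpu")
    for name in sd_a:
        if not torch.equal(sd_a[name], sd_b[name]):
            diff = (sd_a[name].float() - sd_b[name].float()).abs().max().item()
            print(f"{name}: max abs diff {diff:.3e}")

# e.g. compare_state_dicts("ddp_model.pt", "zero1_model_fp32.pt")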

GuanhuaWang self-assigned this Feb 5, 2024
@jpatel-bdai
Author

@GuanhuaWang, @tjruwase and @jomayeri Do you have any findings to share on this? Is there a minimal example comparing DDP and Deepspeed ZeRO where the parameter updates are identical?

tjruwase assigned tohtana and unassigned GuanhuaWang and jomayeri Apr 17, 2024
@tohtana
Contributor

tohtana commented Apr 17, 2024

Hi @jpatel-bdai,
We recently fixed multiple accuracy issues (e.g. #5104, #5105, #5150, and #5170). Some users reported that mismatches in loss values were resolved by updating from 0.12.* to 0.14.*.
Can you try the latest version if you haven't?

@tohtana
Contributor

tohtana commented Apr 17, 2024

@jpatel-bdai Let me share my verification script.

As long as I set FP32, PyTorch's Adam, and NP=2, it showed exact matches with PyTorch.
Currently this does not work well with FP16/BF16. I would appreciate it if you have any ideas for improvement.
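
For reference, a hedged sketch of how that FP32 + PyTorch Adam + NP=2 setup could look with Lightning's DeepSpeedStrategy; whether an extra flag such as zero_allow_untested_optimizer is needed depends on the installed DeepSpeed/Lightning versions:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

trainer = Trainer(
    devices=2,                            # NP=2, as in the verification above
    accelerator="cuda",
    precision="32-true",                  # full FP32, no FP16/BF16 loss scaling involved
    strategy=DeepSpeedStrategy(stage=1),  # ZeRO stage under test
)
# configure_optimizers() should return torch.optim.Adam (and no optimizer should be set in a
# DeepSpeed config) so the plain PyTorch optimizer is the one being compared.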

@tohtana
Contributor

tohtana commented Jul 22, 2024

Let me close this issue as we haven't had a new report for a while. Please feel free to reopen it if you still see the issue.

tohtana closed this as completed Jul 22, 2024