
AMP BF16 issue with batch norm layer #8496

Closed
lukeliu15 opened this issue Dec 17, 2024 · 0 comments · Fixed by #8498 or #8556
🐛 Bug

The following runtime error is raised when using autocast with dtype=torch.bfloat16 on a model containing a conv + batch norm layer:

RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %batch-norm-training.15 = (bf16[4,16,32,32]{3,2,1,0}, bf16[16]{0}, bf16[16]{0}) batch-norm-training(bf16[4,16,32,32]{3,2,1,0} %add.14, f32[16]{0} %p2.3, f32[16]{0} %p3.4), epsilon=1e-05, feature_index=1, but mixed precision is disallowed.

The HLO above shows batch-norm-training receiving bf16 activations (%add.14) together with f32 scale and offset parameters (%p2.3, %p3.4), which XLA rejects as mixed precision. The issue is NOT reproducible with dtype=torch.float16, or with torch.cuda.amp.autocast without XLA.

To Reproduce

import os

# Register installed PJRT plugins (e.g. the CUDA plugin) with torch_xla.
os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"

import torch
from torch import nn
from torch_xla.amp import autocast
import torch_xla.core.xla_model as xm

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

def main():
    device = xm.xla_device()
    model = SimpleModel().to(device)
    inputs = torch.randn(4, 3, 32, 32).to(device)

    # bf16 autocast: traces the conv + batch norm under bfloat16.
    with autocast(device, dtype=torch.bfloat16):
        output = model(inputs)

    # Compile and execute the traced graph; the error surfaces here.
    xm.mark_step()


if __name__ == '__main__':
    main()
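
For reference (an assumption on my part, not part of the original report): with a CUDA-enabled torch_xla install the script runs directly, and the backend can also be pinned explicitly via the PJRT_DEVICE environment variable. repro.py is a hypothetical filename for the script above.

PJRT_DEVICE=CUDA python repro.py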

Expected behavior

The above code should run without error.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 2.3.0
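
Until a fix lands, one possible workaround (a sketch, not verified against this exact setup; SimpleModelFP32BN is a hypothetical variant of the repro model) is to pin the batch norm to float32 so XLA sees a single precision in batch-norm-training:

import torch
from torch import nn

class SimpleModelFP32BN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        # Under bf16 autocast, x is bf16 while bn's scale/offset stay f32.
        # Cast the activation up to f32 for the normalization, then back
        # to the incoming dtype, so all batch-norm operands match.
        x = self.bn(x.float()).to(x.dtype)
        return x

This trades a little extra memory bandwidth for uniform precision in the batch-norm op.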