🐛 Bug
The following runtime error is raised when using autocast with dtype=torch.bfloat16 on a conv + batchnorm model:
RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %batch-norm-training.15 = (bf16[4,16,32,32]{3,2,1,0}, bf16[16]{0}, bf16[16]{0}) batch-norm-training(bf16[4,16,32,32]{3,2,1,0} %add.14, f32[16]{0} %p2.3, f32[16]{0} %p3.4), epsilon=1e-05, feature_index=1, but mixed precision is disallowed.
This issue is NOT reproducible when using dtype=torch.float16, or when using torch.cuda.amp.autocast without XLA (see the float16 comparison after the repro script below).
To Reproduce
import os
os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"

import torch
from torch import nn
from torch_xla.amp import autocast
import torch_xla.core.xla_model as xm

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

def main():
    device = xm.xla_device()
    model = SimpleModel().to(device)
    inputs = torch.randn(4, 3, 32, 32).to(device)
    with autocast(device, dtype=torch.bfloat16):
        output = model(inputs)
    # The error is raised here, when mark_step() triggers XLA compilation/execution.
    xm.mark_step()

if __name__ == '__main__':
    main()
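For comparison (as noted in the description above), the same script does not hit the error when the autocast dtype is switched to torch.float16; only the context-manager line in main() changes:

    # float16 variant of the same main() body: per the report, this does NOT reproduce the error
    with autocast(device, dtype=torch.float16):
        output = model(inputs)
    xm.mark_step()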
Expected behavior
The above code should run without error.
Environment
Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
torch_xla version: 2.3.0
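Additional context
Below is a minimal diagnostic sketch (not part of the original report) that only runs the conv under autocast and prints the dtypes involved, to illustrate the bf16-vs-f32 mix that the batch-norm-training error message reports. The expected dtypes in the comments are assumptions based on that error message, not verified output.

    import os
    os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"

    import torch
    from torch import nn
    from torch_xla.amp import autocast
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1).to(device)
    bn = nn.BatchNorm2d(16).to(device)
    x = torch.randn(4, 3, 32, 32).to(device)

    with autocast(device, dtype=torch.bfloat16):
        y = conv(x)  # under autocast, the conv output is expected to be bf16

    # The BatchNorm parameters are not cast by autocast, so they presumably stay float32,
    # matching the bf16 activations vs f32 scale/offset seen in the error above.
    print(y.dtype)                           # expected: torch.bfloat16
    print(bn.weight.dtype, bn.bias.dtype)    # expected: torch.float32 torch.float32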